Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c8a7617536 |
@@ -761,8 +761,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
LayerSkipConfig,
|
LayerSkipConfig,
|
||||||
PyramidAttentionBroadcastConfig,
|
PyramidAttentionBroadcastConfig,
|
||||||
SmoothedEnergyGuidanceConfig,
|
SmoothedEnergyGuidanceConfig,
|
||||||
apply_layer_skip,
|
|
||||||
apply_faster_cache,
|
apply_faster_cache,
|
||||||
|
apply_layer_skip,
|
||||||
apply_pyramid_attention_broadcast,
|
apply_pyramid_attention_broadcast,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -1085,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionSAGPipeline,
|
StableDiffusionSAGPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
StableDiffusionXLAdapterPipeline,
|
StableDiffusionXLAdapterPipeline,
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||||
StableDiffusionXLControlNetInpaintPipeline,
|
StableDiffusionXLControlNetInpaintPipeline,
|
||||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||||
@@ -1102,7 +1103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLPAGInpaintPipeline,
|
StableDiffusionXLPAGInpaintPipeline,
|
||||||
StableDiffusionXLPAGPipeline,
|
StableDiffusionXLPAGPipeline,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
StableUnCLIPImg2ImgPipeline,
|
StableUnCLIPImg2ImgPipeline,
|
||||||
StableUnCLIPPipeline,
|
StableUnCLIPPipeline,
|
||||||
StableVideoDiffusionPipeline,
|
StableVideoDiffusionPipeline,
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -119,19 +120,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
|||||||
def _is_apg_enabled(self) -> bool:
|
def _is_apg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|
||||||
|
|
||||||
@@ -156,25 +157,25 @@ def normalized_guidance(
|
|||||||
):
|
):
|
||||||
diff = pred_cond - pred_uncond
|
diff = pred_cond - pred_uncond
|
||||||
dim = [-i for i in range(1, len(diff.shape))]
|
dim = [-i for i in range(1, len(diff.shape))]
|
||||||
|
|
||||||
if momentum_buffer is not None:
|
if momentum_buffer is not None:
|
||||||
momentum_buffer.update(diff)
|
momentum_buffer.update(diff)
|
||||||
diff = momentum_buffer.running_average
|
diff = momentum_buffer.running_average
|
||||||
|
|
||||||
if norm_threshold > 0:
|
if norm_threshold > 0:
|
||||||
ones = torch.ones_like(diff)
|
ones = torch.ones_like(diff)
|
||||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||||
diff = diff * scale_factor
|
diff = diff * scale_factor
|
||||||
|
|
||||||
v0, v1 = diff.double(), pred_cond.double()
|
v0, v1 = diff.double(), pred_cond.double()
|
||||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||||
v0_orthogonal = v0 - v0_parallel
|
v0_orthogonal = v0 - v0_parallel
|
||||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||||
|
|
||||||
pred = pred_cond if use_original_formulation else pred_uncond
|
pred = pred_cond if use_original_formulation else pred_uncond
|
||||||
pred = pred + guidance_scale * normalized_update
|
pred = pred + guidance_scale * normalized_update
|
||||||
|
|
||||||
return pred
|
return pred
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
|
|||||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -113,13 +114,13 @@ class AutoGuidance(BaseGuidance):
|
|||||||
if self._is_ag_enabled() and self.is_unconditional:
|
if self._is_ag_enabled() and self.is_unconditional:
|
||||||
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
|
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
|
||||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||||
|
|
||||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||||
if self._is_ag_enabled() and self.is_unconditional:
|
if self._is_ag_enabled() and self.is_unconditional:
|
||||||
for name in self._auto_guidance_hook_names:
|
for name in self._auto_guidance_hook_names:
|
||||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||||
registry.remove_hook(name, recurse=True)
|
registry.remove_hook(name, recurse=True)
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
@@ -140,9 +141,9 @@ class AutoGuidance(BaseGuidance):
|
|||||||
|
|
||||||
if self.guidance_rescale > 0.0:
|
if self.guidance_rescale > 0.0:
|
||||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||||
|
|
||||||
return pred, {}
|
return pred, {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_conditional(self) -> bool:
|
def is_conditional(self) -> bool:
|
||||||
return self._count_prepared == 1
|
return self._count_prepared == 1
|
||||||
@@ -157,17 +158,17 @@ class AutoGuidance(BaseGuidance):
|
|||||||
def _is_ag_enabled(self) -> bool:
|
def _is_ag_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ class ClassifierFreeGuidance(BaseGuidance):
|
|||||||
self.guidance_scale = guidance_scale
|
self.guidance_scale = guidance_scale
|
||||||
self.guidance_rescale = guidance_rescale
|
self.guidance_rescale = guidance_rescale
|
||||||
self.use_original_formulation = use_original_formulation
|
self.use_original_formulation = use_original_formulation
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
@@ -112,17 +113,17 @@ class ClassifierFreeGuidance(BaseGuidance):
|
|||||||
def _is_cfg_enabled(self) -> bool:
|
def _is_cfg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
|||||||
self.zero_init_steps = zero_init_steps
|
self.zero_init_steps = zero_init_steps
|
||||||
self.guidance_rescale = guidance_rescale
|
self.guidance_rescale = guidance_rescale
|
||||||
self.use_original_formulation = use_original_formulation
|
self.use_original_formulation = use_original_formulation
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
@@ -102,7 +103,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
|||||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||||
|
|
||||||
return pred, {}
|
return pred, {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_conditional(self) -> bool:
|
def is_conditional(self) -> bool:
|
||||||
return self._count_prepared == 1
|
return self._count_prepared == 1
|
||||||
@@ -117,19 +118,19 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
|||||||
def _is_cfg_enabled(self) -> bool:
|
def _is_cfg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -58,10 +58,10 @@ class BaseGuidance:
|
|||||||
|
|
||||||
def disable(self):
|
def disable(self):
|
||||||
self._enabled = False
|
self._enabled = False
|
||||||
|
|
||||||
def enable(self):
|
def enable(self):
|
||||||
self._enabled = True
|
self._enabled = True
|
||||||
|
|
||||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||||
self._step = step
|
self._step = step
|
||||||
self._num_inference_steps = num_inference_steps
|
self._num_inference_steps = num_inference_steps
|
||||||
@@ -104,14 +104,14 @@ class BaseGuidance:
|
|||||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||||
)
|
)
|
||||||
self._input_fields = kwargs
|
self._input_fields = kwargs
|
||||||
|
|
||||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||||
"""
|
"""
|
||||||
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
||||||
subclasses to implement specific model preparation logic.
|
subclasses to implement specific model preparation logic.
|
||||||
"""
|
"""
|
||||||
self._count_prepared += 1
|
self._count_prepared += 1
|
||||||
|
|
||||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||||
"""
|
"""
|
||||||
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
|
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
|
||||||
@@ -119,7 +119,7 @@ class BaseGuidance:
|
|||||||
modifications made during `prepare_models`.
|
modifications made during `prepare_models`.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||||
|
|
||||||
@@ -139,15 +139,15 @@ class BaseGuidance:
|
|||||||
@property
|
@property
|
||||||
def is_conditional(self) -> bool:
|
def is_conditional(self) -> bool:
|
||||||
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_unconditional(self) -> bool:
|
def is_unconditional(self) -> bool:
|
||||||
return not self.is_conditional
|
return not self.is_conditional
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_conditions(self) -> int:
|
def num_conditions(self) -> int:
|
||||||
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
|
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
|
|||||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -148,14 +149,14 @@ class SkipLayerGuidance(BaseGuidance):
|
|||||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||||
|
|
||||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||||
# Remove the hooks after inference
|
# Remove the hooks after inference
|
||||||
for hook_name in self._skip_layer_hook_names:
|
for hook_name in self._skip_layer_hook_names:
|
||||||
registry.remove_hook(hook_name, recurse=True)
|
registry.remove_hook(hook_name, recurse=True)
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||||
if self.num_conditions == 1:
|
if self.num_conditions == 1:
|
||||||
tuple_indices = [0]
|
tuple_indices = [0]
|
||||||
@@ -200,7 +201,7 @@ class SkipLayerGuidance(BaseGuidance):
|
|||||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||||
|
|
||||||
return pred, {}
|
return pred, {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_conditional(self) -> bool:
|
def is_conditional(self) -> bool:
|
||||||
return self._count_prepared == 1 or self._count_prepared == 3
|
return self._count_prepared == 1 or self._count_prepared == 3
|
||||||
@@ -217,31 +218,31 @@ class SkipLayerGuidance(BaseGuidance):
|
|||||||
def _is_cfg_enabled(self) -> bool:
|
def _is_cfg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|
||||||
def _is_slg_enabled(self) -> bool:
|
def _is_slg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||||
|
|
||||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||||
|
|
||||||
return is_within_range and not is_zero
|
return is_within_range and not is_zero
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
|
|||||||
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -141,14 +142,14 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
|||||||
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||||
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
|
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
|
||||||
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
|
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
|
||||||
|
|
||||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||||
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||||
# Remove the hooks after inference
|
# Remove the hooks after inference
|
||||||
for hook_name in self._seg_layer_hook_names:
|
for hook_name in self._seg_layer_hook_names:
|
||||||
registry.remove_hook(hook_name, recurse=True)
|
registry.remove_hook(hook_name, recurse=True)
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||||
if self.num_conditions == 1:
|
if self.num_conditions == 1:
|
||||||
tuple_indices = [0]
|
tuple_indices = [0]
|
||||||
@@ -193,7 +194,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
|||||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||||
|
|
||||||
return pred, {}
|
return pred, {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_conditional(self) -> bool:
|
def is_conditional(self) -> bool:
|
||||||
return self._count_prepared == 1 or self._count_prepared == 3
|
return self._count_prepared == 1 or self._count_prepared == 3
|
||||||
@@ -210,31 +211,31 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
|||||||
def _is_cfg_enabled(self) -> bool:
|
def _is_cfg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|
||||||
def _is_seg_enabled(self) -> bool:
|
def _is_seg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
|
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
|
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||||
|
|
||||||
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
|
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
|
||||||
|
|
||||||
return is_within_range and not is_zero
|
return is_within_range and not is_zero
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
@@ -97,24 +98,24 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
|||||||
def _is_tcfg_enabled(self) -> bool:
|
def _is_tcfg_enabled(self) -> bool:
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_within_range = True
|
is_within_range = True
|
||||||
if self._num_inference_steps is not None:
|
if self._num_inference_steps is not None:
|
||||||
skip_start_step = int(self._start * self._num_inference_steps)
|
skip_start_step = int(self._start * self._num_inference_steps)
|
||||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||||
|
|
||||||
is_close = False
|
is_close = False
|
||||||
if self.use_original_formulation:
|
if self.use_original_formulation:
|
||||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||||
else:
|
else:
|
||||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||||
|
|
||||||
return is_within_range and not is_close
|
return is_within_range and not is_close
|
||||||
|
|
||||||
|
|
||||||
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
|
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
|
||||||
cond_dtype = pred_cond.dtype
|
cond_dtype = pred_cond.dtype
|
||||||
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
|
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
|
||||||
preds = preds.flatten(2)
|
preds = preds.flatten(2)
|
||||||
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
|
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
|
||||||
@@ -125,9 +126,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid
|
|||||||
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
|
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
|
||||||
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
|
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
|
||||||
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
|
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
|
||||||
|
|
||||||
pred = pred_cond if use_original_formulation else pred_uncond
|
pred = pred_cond if use_original_formulation else pred_uncond
|
||||||
shift = pred_cond - pred_uncond
|
shift = pred_cond - pred_uncond
|
||||||
pred = pred + guidance_scale * shift
|
pred = pred + guidance_scale * shift
|
||||||
|
|
||||||
return pred
|
return pred
|
||||||
|
|||||||
@@ -20,7 +20,12 @@ import torch
|
|||||||
|
|
||||||
from ..utils import get_logger
|
from ..utils import get_logger
|
||||||
from ..utils.torch_utils import unwrap_module
|
from ..utils.torch_utils import unwrap_module
|
||||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
|
from ._common import (
|
||||||
|
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||||
|
_ATTENTION_CLASSES,
|
||||||
|
_FEEDFORWARD_CLASSES,
|
||||||
|
_get_submodule_from_fqn,
|
||||||
|
)
|
||||||
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
||||||
from .hooks import HookRegistry, ModelHook
|
from .hooks import HookRegistry, ModelHook
|
||||||
|
|
||||||
@@ -196,15 +201,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
|
|||||||
for i, block in enumerate(transformer_blocks):
|
for i, block in enumerate(transformer_blocks):
|
||||||
if i not in config.indices:
|
if i not in config.indices:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
blocks_found = True
|
blocks_found = True
|
||||||
|
|
||||||
if config.skip_attention and config.skip_ff:
|
if config.skip_attention and config.skip_ff:
|
||||||
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
|
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
|
||||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||||
hook = TransformerBlockSkipHook(config.dropout)
|
hook = TransformerBlockSkipHook(config.dropout)
|
||||||
registry.register_hook(hook, name)
|
registry.register_hook(hook, name)
|
||||||
|
|
||||||
elif config.skip_attention or config.skip_attention_scores:
|
elif config.skip_attention or config.skip_attention_scores:
|
||||||
for submodule_name, submodule in block.named_modules():
|
for submodule_name, submodule in block.named_modules():
|
||||||
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
|
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
|
||||||
@@ -213,7 +218,7 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
|
|||||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||||
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
|
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
|
||||||
registry.register_hook(hook, name)
|
registry.register_hook(hook, name)
|
||||||
|
|
||||||
if config.skip_ff:
|
if config.skip_ff:
|
||||||
for submodule_name, submodule in block.named_modules():
|
for submodule_name, submodule in block.named_modules():
|
||||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -67,7 +67,7 @@ class SmoothedEnergyGuidanceHook(ModelHook):
|
|||||||
|
|
||||||
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
|
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
|
||||||
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
|
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
|
||||||
|
|
||||||
if config.fqn == "auto":
|
if config.fqn == "auto":
|
||||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||||
if hasattr(module, identifier):
|
if hasattr(module, identifier):
|
||||||
@@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
|
|||||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||||
)
|
)
|
||||||
|
|
||||||
if config._query_proj_identifiers is None:
|
if config._query_proj_identifiers is None:
|
||||||
config._query_proj_identifiers = ["to_q"]
|
config._query_proj_identifiers = ["to_q"]
|
||||||
|
|
||||||
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
|
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
|
||||||
blocks_found = False
|
blocks_found = False
|
||||||
for i, block in enumerate(transformer_blocks):
|
for i, block in enumerate(transformer_blocks):
|
||||||
if i not in config.indices:
|
if i not in config.indices:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
blocks_found = True
|
blocks_found = True
|
||||||
|
|
||||||
for submodule_name, submodule in block.named_modules():
|
for submodule_name, submodule in block.named_modules():
|
||||||
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
|
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
|
||||||
continue
|
continue
|
||||||
@@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
|
|||||||
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
|
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
|
||||||
hook = SmoothedEnergyGuidanceHook(blur_sigma)
|
hook = SmoothedEnergyGuidanceHook(blur_sigma)
|
||||||
registry.register_hook(hook, name)
|
registry.register_hook(hook, name)
|
||||||
|
|
||||||
if not blocks_found:
|
if not blocks_found:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||||
@@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
|
|||||||
in the future without warning or guarantee of reproducibility.
|
in the future without warning or guarantee of reproducibility.
|
||||||
"""
|
"""
|
||||||
assert query.ndim == 3
|
assert query.ndim == 3
|
||||||
|
|
||||||
is_inf = sigma > sigma_threshold_inf
|
is_inf = sigma > sigma_threshold_inf
|
||||||
batch_size, seq_len, embed_dim = query.shape
|
batch_size, seq_len, embed_dim = query.shape
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
|
|||||||
query_slice = query[:, :num_square_tokens, :]
|
query_slice = query[:, :num_square_tokens, :]
|
||||||
query_slice = query_slice.permute(0, 2, 1)
|
query_slice = query_slice.permute(0, 2, 1)
|
||||||
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
|
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
|
||||||
|
|
||||||
if is_inf:
|
if is_inf:
|
||||||
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
|
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
|
||||||
kernel_size_half = (kernel_size - 1) / 2
|
kernel_size_half = (kernel_size - 1) / 2
|
||||||
@@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
|
|||||||
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
|
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
|
||||||
query_slice = query_slice.permute(0, 2, 1)
|
query_slice = query_slice.permute(0, 2, 1)
|
||||||
query[:, :num_square_tokens, :] = query_slice.clone()
|
query[:, :num_square_tokens, :] = query_slice.clone()
|
||||||
|
|
||||||
return query
|
return query
|
||||||
|
|||||||
@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .ip_adapter import (
|
from .ip_adapter import (
|
||||||
FluxIPAdapterMixin,
|
FluxIPAdapterMixin,
|
||||||
IPAdapterMixin,
|
IPAdapterMixin,
|
||||||
SD3IPAdapterMixin,
|
|
||||||
ModularIPAdapterMixin,
|
ModularIPAdapterMixin,
|
||||||
|
SD3IPAdapterMixin,
|
||||||
)
|
)
|
||||||
from .lora_pipeline import (
|
from .lora_pipeline import (
|
||||||
AmusedLoraLoaderMixin,
|
AmusedLoraLoaderMixin,
|
||||||
|
|||||||
@@ -703,12 +703,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||||
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||||
from .stable_diffusion_xl import (
|
from .stable_diffusion_xl import (
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
StableDiffusionXLImg2ImgPipeline,
|
StableDiffusionXLImg2ImgPipeline,
|
||||||
StableDiffusionXLInpaintPipeline,
|
StableDiffusionXLInpaintPipeline,
|
||||||
StableDiffusionXLInstructPix2PixPipeline,
|
StableDiffusionXLInstructPix2PixPipeline,
|
||||||
StableDiffusionXLModularLoader,
|
StableDiffusionXLModularLoader,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
)
|
)
|
||||||
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
||||||
from .t2i_adapter import (
|
from .t2i_adapter import (
|
||||||
|
|||||||
@@ -12,21 +12,18 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
from typing import List, Optional, Union, Dict, Any
|
from typing import Any, Dict, List, Optional, Union
|
||||||
import copy
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from ..models.modeling_utils import ModelMixin
|
|
||||||
from .modular_pipeline_utils import ComponentSpec
|
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
@@ -231,17 +228,18 @@ class AutoOffloadStrategy:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
from .modular_pipeline_utils import ComponentSpec
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
class ComponentsManager:
|
class ComponentsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.components = OrderedDict()
|
self.components = OrderedDict()
|
||||||
self.added_time = OrderedDict() # Store when components were added
|
self.added_time = OrderedDict() # Store when components were added
|
||||||
self.collections = OrderedDict() # collection_name -> set of component_names
|
self.collections = OrderedDict() # collection_name -> set of component_names
|
||||||
self.model_hooks = None
|
self.model_hooks = None
|
||||||
self._auto_offload_enabled = False
|
self._auto_offload_enabled = False
|
||||||
|
|
||||||
|
|
||||||
def _get_by_collection(self, collection: str):
|
def _get_by_collection(self, collection: str):
|
||||||
"""
|
"""
|
||||||
Select components by collection name.
|
Select components by collection name.
|
||||||
@@ -252,8 +250,8 @@ class ComponentsManager:
|
|||||||
for component_id in component_ids:
|
for component_id in component_ids:
|
||||||
selected_components[component_id] = self.components[component_id]
|
selected_components[component_id] = self.components[component_id]
|
||||||
return selected_components
|
return selected_components
|
||||||
|
|
||||||
|
|
||||||
def _get_by_load_id(self, load_id: str):
|
def _get_by_load_id(self, load_id: str):
|
||||||
"""
|
"""
|
||||||
Select components by its load_id.
|
Select components by its load_id.
|
||||||
@@ -263,8 +261,8 @@ class ComponentsManager:
|
|||||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
|
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
|
||||||
selected_components[name] = component
|
selected_components[name] = component
|
||||||
return selected_components
|
return selected_components
|
||||||
|
|
||||||
|
|
||||||
def add(self, name, component, collection: Optional[str] = None):
|
def add(self, name, component, collection: Optional[str] = None):
|
||||||
|
|
||||||
for comp_id, comp in self.components.items():
|
for comp_id, comp in self.components.items():
|
||||||
@@ -282,7 +280,7 @@ class ComponentsManager:
|
|||||||
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
||||||
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
|
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# add component to components manager
|
# add component to components manager
|
||||||
self.components[component_id] = component
|
self.components[component_id] = component
|
||||||
@@ -293,8 +291,8 @@ class ComponentsManager:
|
|||||||
self.collections[collection].add(component_id)
|
self.collections[collection].add(component_id)
|
||||||
|
|
||||||
if self._auto_offload_enabled:
|
if self._auto_offload_enabled:
|
||||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||||
|
|
||||||
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
|
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
|
||||||
return component_id
|
return component_id
|
||||||
|
|
||||||
@@ -304,14 +302,14 @@ class ComponentsManager:
|
|||||||
if name not in self.components:
|
if name not in self.components:
|
||||||
logger.warning(f"Component '{name}' not found in ComponentsManager")
|
logger.warning(f"Component '{name}' not found in ComponentsManager")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.components.pop(name)
|
self.components.pop(name)
|
||||||
self.added_time.pop(name)
|
self.added_time.pop(name)
|
||||||
|
|
||||||
for collection in self.collections:
|
for collection in self.collections:
|
||||||
if name in self.collections[collection]:
|
if name in self.collections[collection]:
|
||||||
self.collections[collection].remove(name)
|
self.collections[collection].remove(name)
|
||||||
|
|
||||||
if self._auto_offload_enabled:
|
if self._auto_offload_enabled:
|
||||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||||
|
|
||||||
@@ -341,7 +339,7 @@ class ComponentsManager:
|
|||||||
Dictionary mapping component IDs to components,
|
Dictionary mapping component IDs to components,
|
||||||
or list of (base_name, component) tuples if as_name_component_tuples=True
|
or list of (base_name, component) tuples if as_name_component_tuples=True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if collection:
|
if collection:
|
||||||
if collection not in self.collections:
|
if collection not in self.collections:
|
||||||
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
||||||
@@ -360,16 +358,16 @@ class ComponentsManager:
|
|||||||
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
||||||
return '_'.join(parts[:-1])
|
return '_'.join(parts[:-1])
|
||||||
return component_id
|
return component_id
|
||||||
|
|
||||||
if names is None:
|
if names is None:
|
||||||
if as_name_component_tuples:
|
if as_name_component_tuples:
|
||||||
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
|
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
|
||||||
else:
|
else:
|
||||||
return components
|
return components
|
||||||
|
|
||||||
# Create mapping from component_id to base_name for all components
|
# Create mapping from component_id to base_name for all components
|
||||||
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
|
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
|
||||||
|
|
||||||
def matches_pattern(component_id, pattern, exact_match=False):
|
def matches_pattern(component_id, pattern, exact_match=False):
|
||||||
"""
|
"""
|
||||||
Helper function to check if a component matches a pattern based on its base name.
|
Helper function to check if a component matches a pattern based on its base name.
|
||||||
@@ -380,124 +378,124 @@ class ComponentsManager:
|
|||||||
exact_match: If True, only exact matches to base_name are considered
|
exact_match: If True, only exact matches to base_name are considered
|
||||||
"""
|
"""
|
||||||
base_name = base_names[component_id]
|
base_name = base_names[component_id]
|
||||||
|
|
||||||
# Exact match with base name
|
# Exact match with base name
|
||||||
if exact_match:
|
if exact_match:
|
||||||
return pattern == base_name
|
return pattern == base_name
|
||||||
|
|
||||||
# Prefix match (ends with *)
|
# Prefix match (ends with *)
|
||||||
elif pattern.endswith('*'):
|
elif pattern.endswith('*'):
|
||||||
prefix = pattern[:-1]
|
prefix = pattern[:-1]
|
||||||
return base_name.startswith(prefix)
|
return base_name.startswith(prefix)
|
||||||
|
|
||||||
# Contains match (starts with *)
|
# Contains match (starts with *)
|
||||||
elif pattern.startswith('*'):
|
elif pattern.startswith('*'):
|
||||||
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
|
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
|
||||||
return search in base_name
|
return search in base_name
|
||||||
|
|
||||||
# Exact match (no wildcards)
|
# Exact match (no wildcards)
|
||||||
else:
|
else:
|
||||||
return pattern == base_name
|
return pattern == base_name
|
||||||
|
|
||||||
if isinstance(names, str):
|
if isinstance(names, str):
|
||||||
# Check if this is a "not" pattern
|
# Check if this is a "not" pattern
|
||||||
is_not_pattern = names.startswith('!')
|
is_not_pattern = names.startswith('!')
|
||||||
if is_not_pattern:
|
if is_not_pattern:
|
||||||
names = names[1:] # Remove the ! prefix
|
names = names[1:] # Remove the ! prefix
|
||||||
|
|
||||||
# Handle OR patterns (containing |)
|
# Handle OR patterns (containing |)
|
||||||
if '|' in names:
|
if '|' in names:
|
||||||
terms = names.split('|')
|
terms = names.split('|')
|
||||||
matches = {}
|
matches = {}
|
||||||
|
|
||||||
for comp_id, comp in components.items():
|
for comp_id, comp in components.items():
|
||||||
# For OR patterns with exact names (no wildcards), we do exact matching on base names
|
# For OR patterns with exact names (no wildcards), we do exact matching on base names
|
||||||
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
|
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
|
||||||
|
|
||||||
# Check if any of the terms match this component
|
# Check if any of the terms match this component
|
||||||
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
|
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
|
||||||
|
|
||||||
# Flip the decision if this is a NOT pattern
|
# Flip the decision if this is a NOT pattern
|
||||||
if is_not_pattern:
|
if is_not_pattern:
|
||||||
should_include = not should_include
|
should_include = not should_include
|
||||||
|
|
||||||
if should_include:
|
if should_include:
|
||||||
matches[comp_id] = comp
|
matches[comp_id] = comp
|
||||||
|
|
||||||
log_msg = "NOT " if is_not_pattern else ""
|
log_msg = "NOT " if is_not_pattern else ""
|
||||||
match_type = "exactly matching" if exact_match else "matching any of patterns"
|
match_type = "exactly matching" if exact_match else "matching any of patterns"
|
||||||
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
|
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
|
||||||
|
|
||||||
# Try exact match with a base name
|
# Try exact match with a base name
|
||||||
elif any(names == base_name for base_name in base_names.values()):
|
elif any(names == base_name for base_name in base_names.values()):
|
||||||
# Find all components with this base name
|
# Find all components with this base name
|
||||||
matches = {
|
matches = {
|
||||||
comp_id: comp for comp_id, comp in components.items()
|
comp_id: comp for comp_id, comp in components.items()
|
||||||
if (base_names[comp_id] == names) != is_not_pattern
|
if (base_names[comp_id] == names) != is_not_pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_not_pattern:
|
if is_not_pattern:
|
||||||
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
|
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
||||||
|
|
||||||
# Prefix match (ends with *)
|
# Prefix match (ends with *)
|
||||||
elif names.endswith('*'):
|
elif names.endswith('*'):
|
||||||
prefix = names[:-1]
|
prefix = names[:-1]
|
||||||
matches = {
|
matches = {
|
||||||
comp_id: comp for comp_id, comp in components.items()
|
comp_id: comp for comp_id, comp in components.items()
|
||||||
if base_names[comp_id].startswith(prefix) != is_not_pattern
|
if base_names[comp_id].startswith(prefix) != is_not_pattern
|
||||||
}
|
}
|
||||||
if is_not_pattern:
|
if is_not_pattern:
|
||||||
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
|
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
|
||||||
|
|
||||||
# Contains match (starts with *)
|
# Contains match (starts with *)
|
||||||
elif names.startswith('*'):
|
elif names.startswith('*'):
|
||||||
search = names[1:-1] if names.endswith('*') else names[1:]
|
search = names[1:-1] if names.endswith('*') else names[1:]
|
||||||
matches = {
|
matches = {
|
||||||
comp_id: comp for comp_id, comp in components.items()
|
comp_id: comp for comp_id, comp in components.items()
|
||||||
if (search in base_names[comp_id]) != is_not_pattern
|
if (search in base_names[comp_id]) != is_not_pattern
|
||||||
}
|
}
|
||||||
if is_not_pattern:
|
if is_not_pattern:
|
||||||
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
||||||
|
|
||||||
# Substring match (no wildcards, but not an exact component name)
|
# Substring match (no wildcards, but not an exact component name)
|
||||||
elif any(names in base_name for base_name in base_names.values()):
|
elif any(names in base_name for base_name in base_names.values()):
|
||||||
matches = {
|
matches = {
|
||||||
comp_id: comp for comp_id, comp in components.items()
|
comp_id: comp for comp_id, comp in components.items()
|
||||||
if (names in base_names[comp_id]) != is_not_pattern
|
if (names in base_names[comp_id]) != is_not_pattern
|
||||||
}
|
}
|
||||||
if is_not_pattern:
|
if is_not_pattern:
|
||||||
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
|
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
|
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
|
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
raise ValueError(f"No components found matching pattern '{names}'")
|
raise ValueError(f"No components found matching pattern '{names}'")
|
||||||
|
|
||||||
if as_name_component_tuples:
|
if as_name_component_tuples:
|
||||||
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
|
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
|
||||||
else:
|
else:
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
elif isinstance(names, list):
|
elif isinstance(names, list):
|
||||||
results = {}
|
results = {}
|
||||||
for name in names:
|
for name in names:
|
||||||
result = self.get(name, collection, load_id, as_name_component_tuples=False)
|
result = self.get(name, collection, load_id, as_name_component_tuples=False)
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
|
||||||
if as_name_component_tuples:
|
if as_name_component_tuples:
|
||||||
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
|
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
|
||||||
else:
|
else:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||||
|
|
||||||
@@ -558,14 +556,14 @@ class ComponentsManager:
|
|||||||
raise ValueError(f"Component '{name}' not found in ComponentsManager")
|
raise ValueError(f"Component '{name}' not found in ComponentsManager")
|
||||||
|
|
||||||
component = self.components[name]
|
component = self.components[name]
|
||||||
|
|
||||||
# Build complete info dict first
|
# Build complete info dict first
|
||||||
info = {
|
info = {
|
||||||
"model_id": name,
|
"model_id": name,
|
||||||
"added_time": self.added_time[name],
|
"added_time": self.added_time[name],
|
||||||
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
|
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Additional info for torch.nn.Module components
|
# Additional info for torch.nn.Module components
|
||||||
if isinstance(component, torch.nn.Module):
|
if isinstance(component, torch.nn.Module):
|
||||||
# Check for hook information
|
# Check for hook information
|
||||||
@@ -573,7 +571,7 @@ class ComponentsManager:
|
|||||||
execution_device = None
|
execution_device = None
|
||||||
if has_hook and hasattr(component._hf_hook, "execution_device"):
|
if has_hook and hasattr(component._hf_hook, "execution_device"):
|
||||||
execution_device = component._hf_hook.execution_device
|
execution_device = component._hf_hook.execution_device
|
||||||
|
|
||||||
info.update({
|
info.update({
|
||||||
"class_name": component.__class__.__name__,
|
"class_name": component.__class__.__name__,
|
||||||
"size_gb": get_memory_footprint(component) / (1024**3),
|
"size_gb": get_memory_footprint(component) / (1024**3),
|
||||||
@@ -594,8 +592,8 @@ class ComponentsManager:
|
|||||||
if any("IPAdapter" in ptype for ptype in processor_types):
|
if any("IPAdapter" in ptype for ptype in processor_types):
|
||||||
# Then get scales only from IP-Adapter processors
|
# Then get scales only from IP-Adapter processors
|
||||||
scales = {
|
scales = {
|
||||||
k: v.scale
|
k: v.scale
|
||||||
for k, v in processors.items()
|
for k, v in processors.items()
|
||||||
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
|
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
|
||||||
}
|
}
|
||||||
if scales:
|
if scales:
|
||||||
@@ -609,7 +607,7 @@ class ComponentsManager:
|
|||||||
else:
|
else:
|
||||||
# List of fields requested, return dict with just those fields
|
# List of fields requested, return dict with just those fields
|
||||||
return {k: v for k, v in info.items() if k in fields}
|
return {k: v for k, v in info.items() if k in fields}
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -622,13 +620,13 @@ class ComponentsManager:
|
|||||||
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
||||||
return '_'.join(parts[:-1])
|
return '_'.join(parts[:-1])
|
||||||
return name
|
return name
|
||||||
|
|
||||||
# Extract load_id if available
|
# Extract load_id if available
|
||||||
def get_load_id(component):
|
def get_load_id(component):
|
||||||
if hasattr(component, "_diffusers_load_id"):
|
if hasattr(component, "_diffusers_load_id"):
|
||||||
return component._diffusers_load_id
|
return component._diffusers_load_id
|
||||||
return "N/A"
|
return "N/A"
|
||||||
|
|
||||||
# Format device info compactly
|
# Format device info compactly
|
||||||
def format_device(component, info):
|
def format_device(component, info):
|
||||||
if not info["has_hook"]:
|
if not info["has_hook"]:
|
||||||
@@ -637,24 +635,24 @@ class ComponentsManager:
|
|||||||
device = str(getattr(component, 'device', 'N/A'))
|
device = str(getattr(component, 'device', 'N/A'))
|
||||||
exec_device = str(info['execution_device'] or 'N/A')
|
exec_device = str(info['execution_device'] or 'N/A')
|
||||||
return f"{device}({exec_device})"
|
return f"{device}({exec_device})"
|
||||||
|
|
||||||
# Get all simple names to calculate width
|
# Get all simple names to calculate width
|
||||||
simple_names = [get_simple_name(id) for id in self.components.keys()]
|
simple_names = [get_simple_name(id) for id in self.components.keys()]
|
||||||
|
|
||||||
# Get max length of load_ids for models
|
# Get max length of load_ids for models
|
||||||
load_ids = [
|
load_ids = [
|
||||||
get_load_id(component)
|
get_load_id(component)
|
||||||
for component in self.components.values()
|
for component in self.components.values()
|
||||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
|
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
|
||||||
]
|
]
|
||||||
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
|
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
|
||||||
|
|
||||||
# Collection names
|
# Collection names
|
||||||
collection_names = [
|
collection_names = [
|
||||||
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
|
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
|
||||||
for name in self.components.keys()
|
for name in self.components.keys()
|
||||||
]
|
]
|
||||||
|
|
||||||
col_widths = {
|
col_widths = {
|
||||||
"name": max(15, max(len(name) for name in simple_names)),
|
"name": max(15, max(len(name) for name in simple_names)),
|
||||||
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
|
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
|
||||||
@@ -692,7 +690,7 @@ class ComponentsManager:
|
|||||||
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
||||||
load_id = get_load_id(component)
|
load_id = get_load_id(component)
|
||||||
collection = info["collection"] or "N/A"
|
collection = info["collection"] or "N/A"
|
||||||
|
|
||||||
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
|
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
|
||||||
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
|
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
|
||||||
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
|
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
|
||||||
@@ -712,7 +710,7 @@ class ComponentsManager:
|
|||||||
info = self.get_model_info(name)
|
info = self.get_model_info(name)
|
||||||
simple_name = get_simple_name(name)
|
simple_name = get_simple_name(name)
|
||||||
collection = info["collection"] or "N/A"
|
collection = info["collection"] or "N/A"
|
||||||
|
|
||||||
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
|
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
|
||||||
output += dash_line
|
output += dash_line
|
||||||
|
|
||||||
@@ -726,9 +724,9 @@ class ComponentsManager:
|
|||||||
if info.get("adapters") is not None:
|
if info.get("adapters") is not None:
|
||||||
output += f" Adapters: {info['adapters']}\n"
|
output += f" Adapters: {info['adapters']}\n"
|
||||||
if info.get("ip_adapter"):
|
if info.get("ip_adapter"):
|
||||||
output += f" IP-Adapter: Enabled\n"
|
output += " IP-Adapter: Enabled\n"
|
||||||
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
|
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
|
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
|
||||||
@@ -759,13 +757,13 @@ class ComponentsManager:
|
|||||||
from ..pipelines.pipeline_utils import DiffusionPipeline
|
from ..pipelines.pipeline_utils import DiffusionPipeline
|
||||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
for name, component in pipe.components.items():
|
for name, component in pipe.components.items():
|
||||||
|
|
||||||
if component is None:
|
if component is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add prefix if specified
|
# Add prefix if specified
|
||||||
component_name = f"{prefix}_{name}" if prefix else name
|
component_name = f"{prefix}_{name}" if prefix else name
|
||||||
|
|
||||||
if component_name not in self.components:
|
if component_name not in self.components:
|
||||||
self.add(component_name, component)
|
self.add(component_name, component)
|
||||||
else:
|
else:
|
||||||
@@ -791,13 +789,13 @@ class ComponentsManager:
|
|||||||
ValueError: If no components match or multiple components match
|
ValueError: If no components match or multiple components match
|
||||||
"""
|
"""
|
||||||
results = self.get(name, collection, load_id)
|
results = self.get(name, collection, load_id)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
raise ValueError(f"No components found matching '{name}'")
|
raise ValueError(f"No components found matching '{name}'")
|
||||||
|
|
||||||
if len(results) > 1:
|
if len(results) > 1:
|
||||||
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
|
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
|
||||||
|
|
||||||
return next(iter(results.values()))
|
return next(iter(results.values()))
|
||||||
|
|
||||||
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@@ -823,17 +821,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
if value_tuple not in value_to_keys:
|
if value_tuple not in value_to_keys:
|
||||||
value_to_keys[value_tuple] = []
|
value_to_keys[value_tuple] = []
|
||||||
value_to_keys[value_tuple].append(key)
|
value_to_keys[value_tuple].append(key)
|
||||||
|
|
||||||
def find_common_prefix(keys: List[str]) -> str:
|
def find_common_prefix(keys: List[str]) -> str:
|
||||||
"""Find the shortest common prefix among a list of dot-separated keys."""
|
"""Find the shortest common prefix among a list of dot-separated keys."""
|
||||||
if not keys:
|
if not keys:
|
||||||
return ""
|
return ""
|
||||||
if len(keys) == 1:
|
if len(keys) == 1:
|
||||||
return keys[0]
|
return keys[0]
|
||||||
|
|
||||||
# Split all keys into parts
|
# Split all keys into parts
|
||||||
key_parts = [k.split('.') for k in keys]
|
key_parts = [k.split('.') for k in keys]
|
||||||
|
|
||||||
# Find how many initial parts are common
|
# Find how many initial parts are common
|
||||||
common_length = 0
|
common_length = 0
|
||||||
for parts in zip(*key_parts):
|
for parts in zip(*key_parts):
|
||||||
@@ -841,10 +839,10 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
common_length += 1
|
common_length += 1
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
if common_length == 0:
|
if common_length == 0:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Return the common prefix
|
# Return the common prefix
|
||||||
return '.'.join(key_parts[0][:common_length])
|
return '.'.join(key_parts[0][:common_length])
|
||||||
|
|
||||||
@@ -858,5 +856,5 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
summary[prefix] = value
|
summary[prefix] = value
|
||||||
else:
|
else:
|
||||||
summary[""] = value # Use empty string if no common prefix
|
summary[""] = value # Use empty string if no common prefix
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -12,13 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import re
|
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass, asdict, field, fields
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
|
from dataclasses import dataclass, field, fields
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||||
from ..utils.import_utils import is_torch_available
|
from ..utils.import_utils import is_torch_available
|
||||||
from ..configuration_utils import FrozenDict, ConfigMixin
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -56,50 +57,50 @@ class ComponentSpec:
|
|||||||
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
||||||
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
||||||
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
||||||
|
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
"""Make ComponentSpec hashable, using load_id as the hash value."""
|
"""Make ComponentSpec hashable, using load_id as the hash value."""
|
||||||
return hash((self.name, self.load_id, self.default_creation_method))
|
return hash((self.name, self.load_id, self.default_creation_method))
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
"""Compare ComponentSpec objects based on name and load_id."""
|
"""Compare ComponentSpec objects based on name and load_id."""
|
||||||
if not isinstance(other, ComponentSpec):
|
if not isinstance(other, ComponentSpec):
|
||||||
return False
|
return False
|
||||||
return (self.name == other.name and
|
return (self.name == other.name and
|
||||||
self.load_id == other.load_id and
|
self.load_id == other.load_id and
|
||||||
self.default_creation_method == other.default_creation_method)
|
self.default_creation_method == other.default_creation_method)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_component(cls, name: str, component: torch.nn.Module) -> Any:
|
def from_component(cls, name: str, component: torch.nn.Module) -> Any:
|
||||||
"""Create a ComponentSpec from a Component created by `create` method."""
|
"""Create a ComponentSpec from a Component created by `create` method."""
|
||||||
|
|
||||||
if not hasattr(component, "_diffusers_load_id"):
|
if not hasattr(component, "_diffusers_load_id"):
|
||||||
raise ValueError("Component is not created by `create` method")
|
raise ValueError("Component is not created by `create` method")
|
||||||
|
|
||||||
type_hint = component.__class__
|
type_hint = component.__class__
|
||||||
|
|
||||||
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
|
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
|
||||||
config = component.config
|
config = component.config
|
||||||
else:
|
else:
|
||||||
config = None
|
config = None
|
||||||
|
|
||||||
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
||||||
|
|
||||||
return cls(name=name, type_hint=type_hint, config=config, **load_spec)
|
return cls(name=name, type_hint=type_hint, config=config, **load_spec)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
|
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
|
||||||
"""Create a ComponentSpec from a load_id string."""
|
"""Create a ComponentSpec from a load_id string."""
|
||||||
if load_id == "null":
|
if load_id == "null":
|
||||||
raise ValueError("Cannot create ComponentSpec from null load_id")
|
raise ValueError("Cannot create ComponentSpec from null load_id")
|
||||||
|
|
||||||
# Decode the load_id into a dictionary of loading fields
|
# Decode the load_id into a dictionary of loading fields
|
||||||
load_fields = cls.decode_load_id(load_id)
|
load_fields = cls.decode_load_id(load_id)
|
||||||
|
|
||||||
# Create a new ComponentSpec instance with the decoded fields
|
# Create a new ComponentSpec instance with the decoded fields
|
||||||
return cls(name=name, **load_fields)
|
return cls(name=name, **load_fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def loading_fields(cls) -> List[str]:
|
def loading_fields(cls) -> List[str]:
|
||||||
"""
|
"""
|
||||||
@@ -107,8 +108,8 @@ class ComponentSpec:
|
|||||||
(i.e. those whose field.metadata["loading"] is True).
|
(i.e. those whose field.metadata["loading"] is True).
|
||||||
"""
|
"""
|
||||||
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
|
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def load_id(self) -> str:
|
def load_id(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -118,7 +119,7 @@ class ComponentSpec:
|
|||||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||||
parts = ["null" if p is None else p for p in parts]
|
parts = ["null" if p is None else p for p in parts]
|
||||||
return "|".join(p for p in parts if p)
|
return "|".join(p for p in parts if p)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
|
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
@@ -139,29 +140,29 @@ class ComponentSpec:
|
|||||||
If a segment value is "null", it's replaced with None.
|
If a segment value is "null", it's replaced with None.
|
||||||
Returns None if load_id is "null" (indicating component not loaded from pretrained).
|
Returns None if load_id is "null" (indicating component not loaded from pretrained).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get all loading fields in order
|
# Get all loading fields in order
|
||||||
loading_fields = cls.loading_fields()
|
loading_fields = cls.loading_fields()
|
||||||
result = {f: None for f in loading_fields}
|
result = {f: None for f in loading_fields}
|
||||||
|
|
||||||
if load_id == "null":
|
if load_id == "null":
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Split the load_id
|
# Split the load_id
|
||||||
parts = load_id.split("|")
|
parts = load_id.split("|")
|
||||||
|
|
||||||
# Map parts to loading fields by position
|
# Map parts to loading fields by position
|
||||||
for i, part in enumerate(parts):
|
for i, part in enumerate(parts):
|
||||||
if i < len(loading_fields):
|
if i < len(loading_fields):
|
||||||
# Convert "null" string back to None
|
# Convert "null" string back to None
|
||||||
result[loading_fields[i]] = None if part == "null" else part
|
result[loading_fields[i]] = None if part == "null" else part
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# YiYi TODO: add validator
|
# YiYi TODO: add validator
|
||||||
def create(self, **kwargs) -> Any:
|
def create(self, **kwargs) -> Any:
|
||||||
"""Create the component using the preferred creation method."""
|
"""Create the component using the preferred creation method."""
|
||||||
|
|
||||||
# from_pretrained creation
|
# from_pretrained creation
|
||||||
if self.default_creation_method == "from_pretrained":
|
if self.default_creation_method == "from_pretrained":
|
||||||
return self.create_from_pretrained(**kwargs)
|
return self.create_from_pretrained(**kwargs)
|
||||||
@@ -170,17 +171,17 @@ class ComponentSpec:
|
|||||||
return self.create_from_config(**kwargs)
|
return self.create_from_config(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid creation method: {self.default_creation_method}")
|
raise ValueError(f"Invalid creation method: {self.default_creation_method}")
|
||||||
|
|
||||||
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
|
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
|
||||||
"""Create component using from_config with config."""
|
"""Create component using from_config with config."""
|
||||||
|
|
||||||
if self.type_hint is None or not isinstance(self.type_hint, type):
|
if self.type_hint is None or not isinstance(self.type_hint, type):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`type_hint` is required when using from_config creation method."
|
"`type_hint` is required when using from_config creation method."
|
||||||
)
|
)
|
||||||
|
|
||||||
config = config or self.config or {}
|
config = config or self.config or {}
|
||||||
|
|
||||||
if issubclass(self.type_hint, ConfigMixin):
|
if issubclass(self.type_hint, ConfigMixin):
|
||||||
component = self.type_hint.from_config(config, **kwargs)
|
component = self.type_hint.from_config(config, **kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -193,24 +194,24 @@ class ComponentSpec:
|
|||||||
if k in signature_params:
|
if k in signature_params:
|
||||||
init_kwargs[k] = v
|
init_kwargs[k] = v
|
||||||
component = self.type_hint(**init_kwargs)
|
component = self.type_hint(**init_kwargs)
|
||||||
|
|
||||||
component._diffusers_load_id = "null"
|
component._diffusers_load_id = "null"
|
||||||
if hasattr(component, "config"):
|
if hasattr(component, "config"):
|
||||||
self.config = component.config
|
self.config = component.config
|
||||||
|
|
||||||
return component
|
return component
|
||||||
|
|
||||||
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
||||||
def create_from_pretrained(self, **kwargs) -> Any:
|
def create_from_pretrained(self, **kwargs) -> Any:
|
||||||
"""Create component using from_pretrained."""
|
"""Create component using from_pretrained."""
|
||||||
|
|
||||||
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
||||||
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
||||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
||||||
repo = load_kwargs.pop("repo", None)
|
repo = load_kwargs.pop("repo", None)
|
||||||
if repo is None:
|
if repo is None:
|
||||||
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
raise ValueError("`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||||
|
|
||||||
if self.type_hint is None:
|
if self.type_hint is None:
|
||||||
try:
|
try:
|
||||||
from diffusers import AutoModel
|
from diffusers import AutoModel
|
||||||
@@ -223,19 +224,19 @@ class ComponentSpec:
|
|||||||
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
|
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
|
||||||
|
|
||||||
if repo != self.repo:
|
if repo != self.repo:
|
||||||
self.repo = repo
|
self.repo = repo
|
||||||
for k, v in passed_loading_kwargs.items():
|
for k, v in passed_loading_kwargs.items():
|
||||||
if v is not None:
|
if v is not None:
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
component._diffusers_load_id = self.load_id
|
component._diffusers_load_id = self.load_id
|
||||||
|
|
||||||
return component
|
return component
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
|
@dataclass
|
||||||
class ConfigSpec:
|
class ConfigSpec:
|
||||||
"""Specification for a pipeline configuration parameter."""
|
"""Specification for a pipeline configuration parameter."""
|
||||||
name: str
|
name: str
|
||||||
@@ -254,7 +255,7 @@ class InputParam:
|
|||||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OutputParam:
|
class OutputParam:
|
||||||
"""Specification for an output parameter."""
|
"""Specification for an output parameter."""
|
||||||
name: str
|
name: str
|
||||||
@@ -287,14 +288,14 @@ def format_inputs_short(inputs):
|
|||||||
"""
|
"""
|
||||||
required_inputs = [param for param in inputs if param.required]
|
required_inputs = [param for param in inputs if param.required]
|
||||||
optional_inputs = [param for param in inputs if not param.required]
|
optional_inputs = [param for param in inputs if not param.required]
|
||||||
|
|
||||||
required_str = ", ".join(param.name for param in required_inputs)
|
required_str = ", ".join(param.name for param in required_inputs)
|
||||||
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
|
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
|
||||||
|
|
||||||
inputs_str = required_str
|
inputs_str = required_str
|
||||||
if optional_str:
|
if optional_str:
|
||||||
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
|
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
|
||||||
|
|
||||||
return inputs_str
|
return inputs_str
|
||||||
|
|
||||||
|
|
||||||
@@ -321,18 +322,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
|
|||||||
input_parts.append(f"Required({inp.name})")
|
input_parts.append(f"Required({inp.name})")
|
||||||
else:
|
else:
|
||||||
input_parts.append(inp.name)
|
input_parts.append(inp.name)
|
||||||
|
|
||||||
# Handle modified variables (appear in both inputs and outputs)
|
# Handle modified variables (appear in both inputs and outputs)
|
||||||
inputs_set = {inp.name for inp in intermediates_inputs}
|
inputs_set = {inp.name for inp in intermediates_inputs}
|
||||||
modified_parts = []
|
modified_parts = []
|
||||||
new_output_parts = []
|
new_output_parts = []
|
||||||
|
|
||||||
for out in intermediates_outputs:
|
for out in intermediates_outputs:
|
||||||
if out.name in inputs_set:
|
if out.name in inputs_set:
|
||||||
modified_parts.append(out.name)
|
modified_parts.append(out.name)
|
||||||
else:
|
else:
|
||||||
new_output_parts.append(out.name)
|
new_output_parts.append(out.name)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
if input_parts:
|
if input_parts:
|
||||||
result.append(f" - inputs: {', '.join(input_parts)}")
|
result.append(f" - inputs: {', '.join(input_parts)}")
|
||||||
@@ -340,7 +341,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
|
|||||||
result.append(f" - modified: {', '.join(modified_parts)}")
|
result.append(f" - modified: {', '.join(modified_parts)}")
|
||||||
if new_output_parts:
|
if new_output_parts:
|
||||||
result.append(f" - outputs: {', '.join(new_output_parts)}")
|
result.append(f" - outputs: {', '.join(new_output_parts)}")
|
||||||
|
|
||||||
return "\n".join(result) if result else " (none)"
|
return "\n".join(result) if result else " (none)"
|
||||||
|
|
||||||
|
|
||||||
@@ -358,18 +359,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
|||||||
"""
|
"""
|
||||||
if not params:
|
if not params:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
base_indent = " " * indent_level
|
base_indent = " " * indent_level
|
||||||
param_indent = " " * (indent_level + 4)
|
param_indent = " " * (indent_level + 4)
|
||||||
desc_indent = " " * (indent_level + 8)
|
desc_indent = " " * (indent_level + 8)
|
||||||
formatted_params = []
|
formatted_params = []
|
||||||
|
|
||||||
def get_type_str(type_hint):
|
def get_type_str(type_hint):
|
||||||
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
|
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
|
||||||
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
|
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
|
||||||
return f"Union[{', '.join(types)}]"
|
return f"Union[{', '.join(types)}]"
|
||||||
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
|
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
|
||||||
|
|
||||||
def wrap_text(text, indent, max_length):
|
def wrap_text(text, indent, max_length):
|
||||||
"""Wrap text while preserving markdown links and maintaining indentation."""
|
"""Wrap text while preserving markdown links and maintaining indentation."""
|
||||||
words = text.split()
|
words = text.split()
|
||||||
@@ -379,7 +380,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
|||||||
|
|
||||||
for word in words:
|
for word in words:
|
||||||
word_length = len(word) + (1 if current_line else 0)
|
word_length = len(word) + (1 if current_line else 0)
|
||||||
|
|
||||||
if current_line and current_length + word_length > max_length:
|
if current_line and current_length + word_length > max_length:
|
||||||
lines.append(" ".join(current_line))
|
lines.append(" ".join(current_line))
|
||||||
current_line = [word]
|
current_line = [word]
|
||||||
@@ -387,20 +388,20 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
|||||||
else:
|
else:
|
||||||
current_line.append(word)
|
current_line.append(word)
|
||||||
current_length += word_length
|
current_length += word_length
|
||||||
|
|
||||||
if current_line:
|
if current_line:
|
||||||
lines.append(" ".join(current_line))
|
lines.append(" ".join(current_line))
|
||||||
|
|
||||||
return f"\n{indent}".join(lines)
|
return f"\n{indent}".join(lines)
|
||||||
|
|
||||||
# Add the header
|
# Add the header
|
||||||
formatted_params.append(f"{base_indent}{header}:")
|
formatted_params.append(f"{base_indent}{header}:")
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
# Format parameter name and type
|
# Format parameter name and type
|
||||||
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
||||||
param_str = f"{param_indent}{param.name} (`{type_str}`"
|
param_str = f"{param_indent}{param.name} (`{type_str}`"
|
||||||
|
|
||||||
# Add optional tag and default value if parameter is an InputParam and optional
|
# Add optional tag and default value if parameter is an InputParam and optional
|
||||||
if hasattr(param, "required"):
|
if hasattr(param, "required"):
|
||||||
if not param.required:
|
if not param.required:
|
||||||
@@ -408,7 +409,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
|||||||
if param.default is not None:
|
if param.default is not None:
|
||||||
param_str += f", defaults to {param.default}"
|
param_str += f", defaults to {param.default}"
|
||||||
param_str += "):"
|
param_str += "):"
|
||||||
|
|
||||||
# Add description on a new line with additional indentation and wrapping
|
# Add description on a new line with additional indentation and wrapping
|
||||||
if param.description:
|
if param.description:
|
||||||
desc = re.sub(
|
desc = re.sub(
|
||||||
@@ -418,9 +419,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
|||||||
)
|
)
|
||||||
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
|
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
|
||||||
param_str += f"\n{desc_indent}{wrapped_desc}"
|
param_str += f"\n{desc_indent}{wrapped_desc}"
|
||||||
|
|
||||||
formatted_params.append(param_str)
|
formatted_params.append(param_str)
|
||||||
|
|
||||||
return "\n\n".join(formatted_params)
|
return "\n\n".join(formatted_params)
|
||||||
|
|
||||||
|
|
||||||
@@ -466,42 +467,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
|
|||||||
"""
|
"""
|
||||||
if not components:
|
if not components:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
base_indent = " " * indent_level
|
base_indent = " " * indent_level
|
||||||
component_indent = " " * (indent_level + 4)
|
component_indent = " " * (indent_level + 4)
|
||||||
formatted_components = []
|
formatted_components = []
|
||||||
|
|
||||||
# Add the header
|
# Add the header
|
||||||
formatted_components.append(f"{base_indent}Components:")
|
formatted_components.append(f"{base_indent}Components:")
|
||||||
if add_empty_lines:
|
if add_empty_lines:
|
||||||
formatted_components.append("")
|
formatted_components.append("")
|
||||||
|
|
||||||
# Add each component with optional empty lines between them
|
# Add each component with optional empty lines between them
|
||||||
for i, component in enumerate(components):
|
for i, component in enumerate(components):
|
||||||
# Get type name, handling special cases
|
# Get type name, handling special cases
|
||||||
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
|
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
|
||||||
|
|
||||||
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
|
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
|
||||||
if component.description:
|
if component.description:
|
||||||
component_desc += f": {component.description}"
|
component_desc += f": {component.description}"
|
||||||
|
|
||||||
# Get the loading fields dynamically
|
# Get the loading fields dynamically
|
||||||
loading_field_values = []
|
loading_field_values = []
|
||||||
for field_name in component.loading_fields():
|
for field_name in component.loading_fields():
|
||||||
field_value = getattr(component, field_name)
|
field_value = getattr(component, field_name)
|
||||||
if field_value is not None:
|
if field_value is not None:
|
||||||
loading_field_values.append(f"{field_name}={field_value}")
|
loading_field_values.append(f"{field_name}={field_value}")
|
||||||
|
|
||||||
# Add loading field information if available
|
# Add loading field information if available
|
||||||
if loading_field_values:
|
if loading_field_values:
|
||||||
component_desc += f" [{', '.join(loading_field_values)}]"
|
component_desc += f" [{', '.join(loading_field_values)}]"
|
||||||
|
|
||||||
formatted_components.append(component_desc)
|
formatted_components.append(component_desc)
|
||||||
|
|
||||||
# Add an empty line after each component except the last one
|
# Add an empty line after each component except the last one
|
||||||
if add_empty_lines and i < len(components) - 1:
|
if add_empty_lines and i < len(components) - 1:
|
||||||
formatted_components.append("")
|
formatted_components.append("")
|
||||||
|
|
||||||
return "\n".join(formatted_components)
|
return "\n".join(formatted_components)
|
||||||
|
|
||||||
|
|
||||||
@@ -519,27 +520,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
|||||||
"""
|
"""
|
||||||
if not configs:
|
if not configs:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
base_indent = " " * indent_level
|
base_indent = " " * indent_level
|
||||||
config_indent = " " * (indent_level + 4)
|
config_indent = " " * (indent_level + 4)
|
||||||
formatted_configs = []
|
formatted_configs = []
|
||||||
|
|
||||||
# Add the header
|
# Add the header
|
||||||
formatted_configs.append(f"{base_indent}Configs:")
|
formatted_configs.append(f"{base_indent}Configs:")
|
||||||
if add_empty_lines:
|
if add_empty_lines:
|
||||||
formatted_configs.append("")
|
formatted_configs.append("")
|
||||||
|
|
||||||
# Add each config with optional empty lines between them
|
# Add each config with optional empty lines between them
|
||||||
for i, config in enumerate(configs):
|
for i, config in enumerate(configs):
|
||||||
config_desc = f"{config_indent}{config.name} (default: {config.default})"
|
config_desc = f"{config_indent}{config.name} (default: {config.default})"
|
||||||
if config.description:
|
if config.description:
|
||||||
config_desc += f": {config.description}"
|
config_desc += f": {config.description}"
|
||||||
formatted_configs.append(config_desc)
|
formatted_configs.append(config_desc)
|
||||||
|
|
||||||
# Add an empty line after each config except the last one
|
# Add an empty line after each config except the last one
|
||||||
if add_empty_lines and i < len(configs) - 1:
|
if add_empty_lines and i < len(configs) - 1:
|
||||||
formatted_configs.append("")
|
formatted_configs.append("")
|
||||||
|
|
||||||
return "\n".join(formatted_configs)
|
return "\n".join(formatted_configs)
|
||||||
|
|
||||||
|
|
||||||
@@ -584,9 +585,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class
|
|||||||
|
|
||||||
# Add inputs section
|
# Add inputs section
|
||||||
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
|
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
|
||||||
|
|
||||||
# Add outputs section
|
# Add outputs section
|
||||||
output += "\n\n"
|
output += "\n\n"
|
||||||
output += format_output_params(outputs, indent_level=2)
|
output += format_output_params(outputs, indent_level=2)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -334,6 +334,7 @@ def maybe_raise_or_warn(
|
|||||||
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
|
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
|
||||||
def simple_get_class_obj(library_name, class_name):
|
def simple_get_class_obj(library_name, class_name):
|
||||||
from diffusers import pipelines
|
from diffusers import pipelines
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
|
|
||||||
if is_pipeline_module:
|
if is_pipeline_module:
|
||||||
@@ -345,6 +346,7 @@ def simple_get_class_obj(library_name, class_name):
|
|||||||
|
|
||||||
return class_obj
|
return class_obj
|
||||||
|
|
||||||
|
|
||||||
def get_class_obj_and_candidates(
|
def get_class_obj_and_candidates(
|
||||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1120,7 +1120,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||||
automatically detect the available accelerator and use.
|
automatically detect the available accelerator and use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
||||||
|
|
||||||
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
|
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
||||||
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
||||||
from .pipeline_stable_diffusion_xl_modular import (
|
from .pipeline_stable_diffusion_xl_modular import (
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
StableDiffusionXLControlNetDenoiseStep,
|
StableDiffusionXLControlNetDenoiseStep,
|
||||||
StableDiffusionXLDecodeLatentsStep,
|
StableDiffusionXLDecodeLatentsStep,
|
||||||
StableDiffusionXLDenoiseStep,
|
StableDiffusionXLDenoiseStep,
|
||||||
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLPrepareLatentsStep,
|
StableDiffusionXLPrepareLatentsStep,
|
||||||
StableDiffusionXLSetTimestepsStep,
|
StableDiffusionXLSetTimestepsStep,
|
||||||
StableDiffusionXLTextEncoderStep,
|
StableDiffusionXLTextEncoderStep,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
+287
-289
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utilities to dynamically load objects from the Hub."""
|
"""Utilities to dynamically load objects from the Hub."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@@ -21,8 +22,9 @@ import os
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, ModuleType, Optional, Union
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, model_info
|
from huggingface_hub import hf_hub_download, model_info
|
||||||
@@ -37,6 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
||||||
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
||||||
|
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_diffusers_versions():
|
def get_diffusers_versions():
|
||||||
@@ -154,15 +157,132 @@ def check_imports(filename):
|
|||||||
return get_relative_imports(filename)
|
return get_relative_imports(filename)
|
||||||
|
|
||||||
|
|
||||||
def get_class_in_module(class_name, module_path):
|
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
|
||||||
|
if trust_remote_code is None:
|
||||||
|
if has_local_code:
|
||||||
|
trust_remote_code = False
|
||||||
|
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||||
|
prev_sig_handler = None
|
||||||
|
try:
|
||||||
|
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||||
|
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||||
|
while trust_remote_code is None:
|
||||||
|
answer = input(
|
||||||
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||||
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||||
|
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||||
|
f"Do you wish to run the custom code? [y/N] "
|
||||||
|
)
|
||||||
|
if answer.lower() in ["yes", "y", "1"]:
|
||||||
|
trust_remote_code = True
|
||||||
|
elif answer.lower() in ["no", "n", "0", ""]:
|
||||||
|
trust_remote_code = False
|
||||||
|
signal.alarm(0)
|
||||||
|
except Exception:
|
||||||
|
# OS which does not support signal.SIGALRM
|
||||||
|
raise ValueError(
|
||||||
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||||
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||||
|
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if prev_sig_handler is not None:
|
||||||
|
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||||
|
signal.alarm(0)
|
||||||
|
elif has_remote_code:
|
||||||
|
# For the CI which puts the timeout at 0
|
||||||
|
_raise_timeout_error(None, None)
|
||||||
|
|
||||||
|
if has_remote_code and not has_local_code and not trust_remote_code:
|
||||||
|
raise ValueError(
|
||||||
|
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||||
|
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||||
|
" set the option `trust_remote_code=True` to remove this error."
|
||||||
|
)
|
||||||
|
|
||||||
|
return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_in_modular_module(
|
||||||
|
class_name: str,
|
||||||
|
module_path: Union[str, os.PathLike],
|
||||||
|
*,
|
||||||
|
force_reload: bool = False,
|
||||||
|
) -> type:
|
||||||
|
"""
|
||||||
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_name (`str`): The name of the class to import.
|
||||||
|
module_path (`str` or `os.PathLike`): The path to the module to import.
|
||||||
|
force_reload (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
|
||||||
|
Otherwise, the module is only reloaded if the file has changed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`typing.Type`: The class looked for.
|
||||||
|
"""
|
||||||
|
name = os.path.normpath(module_path)
|
||||||
|
if name.endswith(".py"):
|
||||||
|
name = name[:-3]
|
||||||
|
name = name.replace(os.path.sep, ".")
|
||||||
|
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||||
|
with _HF_REMOTE_CODE_LOCK:
|
||||||
|
if force_reload:
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||||
|
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||||
|
|
||||||
|
# Hash the module file and all its relative imports to check if we need to reload it
|
||||||
|
module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
|
||||||
|
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
|
||||||
|
|
||||||
|
module: ModuleType
|
||||||
|
if cached_module is None:
|
||||||
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
|
# insert it into sys.modules before any loading begins
|
||||||
|
sys.modules[name] = module
|
||||||
|
else:
|
||||||
|
module = cached_module
|
||||||
|
# reload in both cases, unless the module is already imported and the hash hits
|
||||||
|
if getattr(module, "__transformers_module_hash__", "") != module_hash:
|
||||||
|
module_spec.loader.exec_module(module)
|
||||||
|
module.__transformers_module_hash__ = module_hash
|
||||||
|
|
||||||
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_in_module(class_name, module_path, force_reload=False):
|
||||||
"""
|
"""
|
||||||
Import a module on the cache directory for modules and extract a class from it.
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
"""
|
"""
|
||||||
module_path = module_path.replace(os.path.sep, ".")
|
name = os.path.normpath(module_path)
|
||||||
module = importlib.import_module(module_path)
|
if name.endswith(".py"):
|
||||||
|
name = name[:-3]
|
||||||
|
name = name.replace(os.path.sep, ".")
|
||||||
|
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||||
|
|
||||||
|
with _HF_REMOTE_CODE_LOCK:
|
||||||
|
if force_reload:
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||||
|
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||||
|
|
||||||
|
module: ModuleType
|
||||||
|
if cached_module is None:
|
||||||
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
|
# insert it into sys.modules before any loading begins
|
||||||
|
sys.modules[name] = module
|
||||||
|
else:
|
||||||
|
module = cached_module
|
||||||
|
|
||||||
|
module_spec.loader.exec_module(module)
|
||||||
|
|
||||||
if class_name is None:
|
if class_name is None:
|
||||||
return find_pipeline_class(module)
|
return find_pipeline_class(module)
|
||||||
|
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
@@ -203,6 +323,7 @@ def get_cached_module_file(
|
|||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
is_modular: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
||||||
@@ -257,7 +378,7 @@ def get_cached_module_file(
|
|||||||
if os.path.isfile(module_file_or_url):
|
if os.path.isfile(module_file_or_url):
|
||||||
resolved_module_file = module_file_or_url
|
resolved_module_file = module_file_or_url
|
||||||
submodule = "local"
|
submodule = "local"
|
||||||
elif pretrained_model_name_or_path.count("/") == 0:
|
elif pretrained_model_name_or_path.count("/") == 0 and not is_modular:
|
||||||
available_versions = get_diffusers_versions()
|
available_versions = get_diffusers_versions()
|
||||||
# cut ".dev0"
|
# cut ".dev0"
|
||||||
latest_version = "v" + ".".join(__version__.split(".")[:3])
|
latest_version = "v" + ".".join(__version__.split(".")[:3])
|
||||||
@@ -297,6 +418,24 @@ def get_cached_module_file(
|
|||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
elif is_modular:
|
||||||
|
try:
|
||||||
|
# Load from URL or cache if already cached
|
||||||
|
resolved_module_file = hf_hub_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
module_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||||
|
except EnvironmentError:
|
||||||
|
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||||
|
raise
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
@@ -381,6 +520,7 @@ def get_class_from_dynamic_module(
|
|||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
is_modular: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -453,5 +593,7 @@ def get_class_from_dynamic_module(
|
|||||||
token=token,
|
token=token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
is_modular=is_modular,
|
||||||
)
|
)
|
||||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
__import__("ipdb").set_trace()
|
||||||
|
return get_class_in_module(class_name, final_module)
|
||||||
|
|||||||
Reference in New Issue
Block a user