Compare commits

...

1 Commits

Author SHA1 Message Date
DN6 c8a7617536 update 2025-05-12 19:37:28 +05:30
21 changed files with 927 additions and 754 deletions
+2 -2
View File
@@ -761,8 +761,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig, LayerSkipConfig,
PyramidAttentionBroadcastConfig, PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig, SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache, apply_faster_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast, apply_pyramid_attention_broadcast,
) )
from .models import ( from .models import (
@@ -1085,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionSAGPipeline, StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
StableDiffusionXLAdapterPipeline, StableDiffusionXLAdapterPipeline,
StableDiffusionXLAutoPipeline,
StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline, StableDiffusionXLControlNetPAGImg2ImgPipeline,
@@ -1102,7 +1103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline, StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline, StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline, StableUnCLIPPipeline,
StableVideoDiffusionPipeline, StableVideoDiffusionPipeline,
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
+2 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Union, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
+2 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Union, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Union, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, List, TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional
import torch import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING: if TYPE_CHECKING:
from ..pipelines.modular_pipeline import BlockState from ..pipelines.modular_pipeline import BlockState
+6 -1
View File
@@ -20,7 +20,12 @@ import torch
from ..utils import get_logger from ..utils import get_logger
from ..utils.torch_utils import unwrap_module from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn from ._common import (
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
_ATTENTION_CLASSES,
_FEEDFORWARD_CLASSES,
_get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook from .hooks import HookRegistry, ModelHook
@@ -14,7 +14,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple from typing import List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
+1 -1
View File
@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import ( from .ip_adapter import (
FluxIPAdapterMixin, FluxIPAdapterMixin,
IPAdapterMixin, IPAdapterMixin,
SD3IPAdapterMixin,
ModularIPAdapterMixin, ModularIPAdapterMixin,
SD3IPAdapterMixin,
) )
from .lora_pipeline import ( from .lora_pipeline import (
AmusedLoraLoaderMixin, AmusedLoraLoaderMixin,
+1 -1
View File
@@ -703,12 +703,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_sag import StableDiffusionSAGPipeline from .stable_diffusion_sag import StableDiffusionSAGPipeline
from .stable_diffusion_xl import ( from .stable_diffusion_xl import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline, StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularLoader, StableDiffusionXLModularLoader,
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
) )
from .stable_video_diffusion import StableVideoDiffusionPipeline from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import ( from .t2i_adapter import (
@@ -12,21 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import time
from collections import OrderedDict from collections import OrderedDict
from itertools import combinations from itertools import combinations
from typing import List, Optional, Union, Dict, Any from typing import Any, Dict, List, Optional, Union
import copy
import torch import torch
import time
from dataclasses import dataclass
from ..utils import ( from ..utils import (
is_accelerate_available, is_accelerate_available,
logging, logging,
) )
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
if is_accelerate_available(): if is_accelerate_available():
@@ -231,8 +228,9 @@ class AutoOffloadStrategy:
from .modular_pipeline_utils import ComponentSpec
import uuid import uuid
class ComponentsManager: class ComponentsManager:
def __init__(self): def __init__(self):
self.components = OrderedDict() self.components = OrderedDict()
@@ -726,7 +724,7 @@ class ComponentsManager:
if info.get("adapters") is not None: if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n" output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"): if info.get("ip_adapter"):
output += f" IP-Adapter: Enabled\n" output += " IP-Adapter: Enabled\n"
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
return output return output
+156 -136
View File
@@ -12,29 +12,27 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import os
import traceback import traceback
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Union, Optional, Type from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from tqdm.auto import tqdm
import re
import os
import importlib
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from tqdm.auto import tqdm
from ..configuration_utils import ConfigMixin, FrozenDict from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import ( from ..utils import (
is_accelerate_available,
is_accelerate_version,
logging,
PushToHubMixin, PushToHubMixin,
is_accelerate_available,
logging,
) )
from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from .components_manager import ComponentsManager
from .modular_pipeline_utils import ( from .modular_pipeline_utils import (
ComponentSpec, ComponentSpec,
ConfigSpec, ConfigSpec,
@@ -42,18 +40,15 @@ from .modular_pipeline_utils import (
OutputParam, OutputParam,
format_components, format_components,
format_configs, format_configs,
format_input_params,
format_inputs_short, format_inputs_short,
format_intermediates_short, format_intermediates_short,
format_output_params,
format_params,
make_doc_string, make_doc_string,
) )
from .components_manager import ComponentsManager from .pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
from copy import deepcopy
if is_accelerate_available(): if is_accelerate_available():
import accelerate pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -108,10 +103,7 @@ class PipelineState:
intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
return ( return (
f"PipelineState(\n" f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" f" intermediates={{\n{intermediates}\n }}\n" f")"
f" inputs={{\n{inputs}\n }},\n"
f" intermediates={{\n{intermediates}\n }}\n"
f")"
) )
@@ -120,6 +112,7 @@ class BlockState:
""" """
Container for block state data with attribute access and formatted representation. Container for block state data with attribute access and formatted representation.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(self, key, value) setattr(self, key, value)
@@ -158,20 +151,66 @@ class BlockState:
return f"BlockState(\n{attributes}\n)" return f"BlockState(\n{attributes}\n)"
class ModularPipelineMixin(ConfigMixin):
class ModularPipelineMixin:
""" """
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
""" """
config_name = "config.json"
def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"subfolder",
"token",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, False, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError("")
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
block_cls = get_class_from_dynamic_module(
pretrained_model_name_or_path,
module_file=module_file,
class_name=class_name,
is_modular=True,
**hub_kwargs,
**kwargs,
)
return block_cls()
def setup_loader(
self,
modular_repo: Optional[Union[str, os.PathLike]] = None,
component_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
):
""" """
create a mouldar loader, optionally accept modular_repo to load from hub. create a ModularLoader, optionally accept modular_repo to load from hub.
""" """
# Import components loader (it is model-specific class) # Import components loader (it is model-specific class)
loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
diffusers_module = importlib.import_module("diffusers") diffusers_module = importlib.import_module("diffusers")
loader_class = getattr(diffusers_module, loader_class_name) loader_class = getattr(diffusers_module, loader_class_name)
@@ -181,8 +220,9 @@ class ModularPipelineMixin:
# Create the loader with the updated specs # Create the loader with the updated specs
specs = component_specs + config_specs specs = component_specs + config_specs
self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) self.loader = loader_class(
specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection
)
@property @property
def default_call_parameters(self) -> Dict[str, Any]: def default_call_parameters(self) -> Dict[str, Any]:
@@ -238,7 +278,6 @@ class ModularPipelineMixin:
if output is None: if output is None:
return state return state
elif isinstance(output, str): elif isinstance(output, str):
return state.get_intermediate(output) return state.get_intermediate(output)
@@ -268,7 +307,6 @@ class ModularPipelineMixin:
class PipelineBlock(ModularPipelineMixin): class PipelineBlock(ModularPipelineMixin):
model_name = None model_name = None
@property @property
@@ -284,7 +322,6 @@ class PipelineBlock(ModularPipelineMixin):
def expected_configs(self) -> List[ConfigSpec]: def expected_configs(self) -> List[ConfigSpec]:
return [] return []
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
@property @property
def inputs(self) -> List[InputParam]: def inputs(self) -> List[InputParam]:
@@ -322,7 +359,6 @@ class PipelineBlock(ModularPipelineMixin):
input_names.append(input_param.name) input_names.append(input_param.name)
return input_names return input_names
def __call__(self, pipeline, state: PipelineState) -> PipelineState: def __call__(self, pipeline, state: PipelineState) -> PipelineState:
raise NotImplementedError("__call__ method must be implemented in subclasses") raise NotImplementedError("__call__ method must be implemented in subclasses")
@@ -331,14 +367,14 @@ class PipelineBlock(ModularPipelineMixin):
base_class = self.__class__.__bases__[0].__name__ base_class = self.__class__.__bases__[0].__name__
# Format description with proper indentation # Format description with proper indentation
desc_lines = self.description.split('\n') desc_lines = self.description.split("\n")
desc = [] desc = []
# First line with "Description:" label # First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}") desc.append(f" Description: {desc_lines[0]}")
# Subsequent lines with proper indentation # Subsequent lines with proper indentation
if len(desc_lines) > 1: if len(desc_lines) > 1:
desc.extend(f" {line}" for line in desc_lines[1:]) desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n' desc = "\n".join(desc) + "\n"
# Components section - use format_components with add_empty_lines=False # Components section - use format_components with add_empty_lines=False
expected_components = getattr(self, "expected_components", []) expected_components = getattr(self, "expected_components", [])
@@ -355,7 +391,9 @@ class PipelineBlock(ModularPipelineMixin):
inputs = "Inputs:\n " + inputs_str inputs = "Inputs:\n " + inputs_str
# Intermediates section # Intermediates section
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) intermediates_str = format_intermediates_short(
self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs
)
intermediates = f"Intermediates:\n{intermediates_str}" intermediates = f"Intermediates:\n{intermediates_str}"
return ( return (
@@ -369,7 +407,6 @@ class PipelineBlock(ModularPipelineMixin):
f")" f")"
) )
@property @property
def doc(self): def doc(self):
return make_doc_string( return make_doc_string(
@@ -379,10 +416,9 @@ class PipelineBlock(ModularPipelineMixin):
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
expected_components=self.expected_components, expected_components=self.expected_components,
expected_configs=self.expected_configs expected_configs=self.expected_configs,
) )
def get_block_state(self, state: PipelineState) -> dict: def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary""" """Get all inputs and intermediates in one dictionary"""
data = {} data = {}
@@ -429,9 +465,11 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
for input_param in inputs: for input_param in inputs:
if input_param.name in combined_dict: if input_param.name in combined_dict:
current_param = combined_dict[input_param.name] current_param = combined_dict[input_param.name]
if (current_param.default is not None and if (
input_param.default is not None and current_param.default is not None
current_param.default != input_param.default): and input_param.default is not None
and current_param.default != input_param.default
):
warnings.warn( warnings.warn(
f"Multiple different default values found for input '{input_param.name}': " f"Multiple different default values found for input '{input_param.name}': "
f"{current_param.default} (from block '{value_sources[input_param.name]}') and " f"{current_param.default} (from block '{value_sources[input_param.name]}') and "
@@ -446,6 +484,7 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
return list(combined_dict.values()) return list(combined_dict.values())
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
""" """
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs,
@@ -487,15 +526,15 @@ class AutoPipelineBlocks(ModularPipelineMixin):
blocks[block_name] = block_cls() blocks[block_name] = block_cls()
self.blocks = blocks self.blocks = blocks
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
)
default_blocks = [t for t in self.block_trigger_inputs if t is None] default_blocks = [t for t in self.block_trigger_inputs if t is None]
# can only have 1 or 0 default block, and has to put in the last # can only have 1 or 0 default block, and has to put in the last
# the order of blocksmatters here because the first block with matching trigger will be dispatched # the order of blocksmatters here because the first block with matching trigger will be dispatched
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
# if both mask and image are provided, it is inpaint; if only image is provided, it is img2img # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img
if len(default_blocks) > 1 or ( if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None
):
raise ValueError( raise ValueError(
f"In {self.__class__.__name__}, exactly one None must be specified as the last element " f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
"in block_trigger_inputs." "in block_trigger_inputs."
@@ -532,7 +571,6 @@ class AutoPipelineBlocks(ModularPipelineMixin):
expected_configs.append(config) expected_configs.append(config)
return expected_configs return expected_configs
@property @property
def required_inputs(self) -> List[str]: def required_inputs(self) -> List[str]:
first_block = next(iter(self.blocks.values())) first_block = next(iter(self.blocks.values()))
@@ -557,7 +595,6 @@ class AutoPipelineBlocks(ModularPipelineMixin):
return list(required_by_all) return list(required_by_all)
# YiYi TODO: add test for this # YiYi TODO: add test for this
@property @property
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
@@ -571,7 +608,6 @@ class AutoPipelineBlocks(ModularPipelineMixin):
input_param.required = False input_param.required = False
return combined_inputs return combined_inputs
@property @property
def intermediates_inputs(self) -> List[str]: def intermediates_inputs(self) -> List[str]:
named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()]
@@ -630,18 +666,19 @@ class AutoPipelineBlocks(ModularPipelineMixin):
Returns a set of all unique trigger input values found in the blocks. Returns a set of all unique trigger input values found in the blocks.
Returns: Set[str] containing all unique block_trigger_inputs values Returns: Set[str] containing all unique block_trigger_inputs values
""" """
def fn_recursive_get_trigger(blocks): def fn_recursive_get_trigger(blocks):
trigger_values = set() trigger_values = set()
if blocks is not None: if blocks is not None:
for name, block in blocks.items(): for name, block in blocks.items():
# Check if current block has trigger inputs(i.e. auto block) # Check if current block has trigger inputs(i.e. auto block)
if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
# Add all non-None values from the trigger inputs list # Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None) trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has blocks, recursively check them # If block has blocks, recursively check them
if hasattr(block, 'blocks'): if hasattr(block, "blocks"):
nested_triggers = fn_recursive_get_trigger(block.blocks) nested_triggers = fn_recursive_get_trigger(block.blocks)
trigger_values.update(nested_triggers) trigger_values.update(nested_triggers)
@@ -660,12 +697,9 @@ class AutoPipelineBlocks(ModularPipelineMixin):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__ base_class = self.__class__.__bases__[0].__name__
header = ( header = (
f"{class_name}(\n Class: {base_class}\n" f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
if base_class and base_class != "object"
else f"{class_name}(\n"
) )
if self.trigger_inputs: if self.trigger_inputs:
header += "\n" header += "\n"
header += " " + "=" * 100 + "\n" header += " " + "=" * 100 + "\n"
@@ -677,14 +711,14 @@ class AutoPipelineBlocks(ModularPipelineMixin):
header += " " + "=" * 100 + "\n\n" header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation # Format description with proper indentation
desc_lines = self.description.split('\n') desc_lines = self.description.split("\n")
desc = [] desc = []
# First line with "Description:" label # First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}") desc.append(f" Description: {desc_lines[0]}")
# Subsequent lines with proper indentation # Subsequent lines with proper indentation
if len(desc_lines) > 1: if len(desc_lines) > 1:
desc.extend(f" {line}" for line in desc_lines[1:]) desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n' desc = "\n".join(desc) + "\n"
# Components section - focus only on expected components # Components section - focus only on expected components
expected_components = getattr(self, "expected_components", []) expected_components = getattr(self, "expected_components", [])
@@ -699,7 +733,7 @@ class AutoPipelineBlocks(ModularPipelineMixin):
for i, (name, block) in enumerate(self.blocks.items()): for i, (name, block) in enumerate(self.blocks.items()):
# Get trigger input for this block # Get trigger input for this block
trigger = None trigger = None
if hasattr(self, 'block_to_trigger_map'): if hasattr(self, "block_to_trigger_map"):
trigger = self.block_to_trigger_map.get(name) trigger = self.block_to_trigger_map.get(name)
# Format the trigger info # Format the trigger info
if trigger is None: if trigger is None:
@@ -715,21 +749,13 @@ class AutoPipelineBlocks(ModularPipelineMixin):
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# Add block description # Add block description
desc_lines = block.description.split('\n') desc_lines = block.description.split("\n")
indented_desc = desc_lines[0] indented_desc = desc_lines[0]
if len(desc_lines) > 1: if len(desc_lines) > 1:
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n" blocks_str += f" Description: {indented_desc}\n\n"
return ( return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")"
f"{header}\n"
f"{desc}\n\n"
f"{components_str}\n\n"
f"{configs_str}\n\n"
f"{blocks_str}"
f")"
)
@property @property
def doc(self): def doc(self):
@@ -740,13 +766,15 @@ class AutoPipelineBlocks(ModularPipelineMixin):
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
expected_components=self.expected_components, expected_components=self.expected_components,
expected_configs=self.expected_configs expected_configs=self.expected_configs,
) )
class SequentialPipelineBlocks(ModularPipelineMixin): class SequentialPipelineBlocks(ModularPipelineMixin):
""" """
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
""" """
block_classes = [] block_classes = []
block_names = [] block_names = []
@@ -798,7 +826,6 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
blocks[block_name] = block_cls() blocks[block_name] = block_cls()
self.blocks = blocks self.blocks = blocks
@property @property
def required_inputs(self) -> List[str]: def required_inputs(self) -> List[str]:
# Get the first block from the dictionary # Get the first block from the dictionary
@@ -884,18 +911,19 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
Returns a set of all unique trigger input values found in the blocks. Returns a set of all unique trigger input values found in the blocks.
Returns: Set[str] containing all unique block_trigger_inputs values Returns: Set[str] containing all unique block_trigger_inputs values
""" """
def fn_recursive_get_trigger(blocks): def fn_recursive_get_trigger(blocks):
trigger_values = set() trigger_values = set()
if blocks is not None: if blocks is not None:
for name, block in blocks.items(): for name, block in blocks.items():
# Check if current block has trigger inputs(i.e. auto block) # Check if current block has trigger inputs(i.e. auto block)
if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
# Add all non-None values from the trigger inputs list # Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None) trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has blocks, recursively check them # If block has blocks, recursively check them
if hasattr(block, 'blocks'): if hasattr(block, "blocks"):
nested_triggers = fn_recursive_get_trigger(block.blocks) nested_triggers = fn_recursive_get_trigger(block.blocks)
trigger_values.update(nested_triggers) trigger_values.update(nested_triggers)
@@ -915,8 +943,8 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
result_blocks = OrderedDict() result_blocks = OrderedDict()
# sequential or PipelineBlock # sequential or PipelineBlock
if not hasattr(block, 'block_trigger_inputs'): if not hasattr(block, "block_trigger_inputs"):
if hasattr(block, 'blocks'): if hasattr(block, "blocks"):
# sequential # sequential
for block_name, block in block.blocks.items(): for block_name, block in block.blocks.items():
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
@@ -925,7 +953,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
# PipelineBlock # PipelineBlock
result_blocks[block_name] = block result_blocks[block_name] = block
# Add this block's output names to active triggers if defined # Add this block's output names to active triggers if defined
if hasattr(block, 'outputs'): if hasattr(block, "outputs"):
active_triggers.update(out.name for out in block.outputs) active_triggers.update(out.name for out in block.outputs)
return result_blocks return result_blocks
@@ -947,13 +975,13 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
if this_block is not None: if this_block is not None:
# sequential/auto # sequential/auto
if hasattr(this_block, 'blocks'): if hasattr(this_block, "blocks"):
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
else: else:
# PipelineBlock # PipelineBlock
result_blocks[block_name] = this_block result_blocks[block_name] = this_block
# Add this block's output names to active triggers if defined # Add this block's output names to active triggers if defined
if hasattr(this_block, 'outputs'): if hasattr(this_block, "outputs"):
active_triggers.update(out.name for out in this_block.outputs) active_triggers.update(out.name for out in this_block.outputs)
return result_blocks return result_blocks
@@ -968,7 +996,6 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
trigger_inputs_all = self.trigger_inputs trigger_inputs_all = self.trigger_inputs
if trigger_inputs is not None: if trigger_inputs is not None:
if not isinstance(trigger_inputs, (list, tuple, set)): if not isinstance(trigger_inputs, (list, tuple, set)):
trigger_inputs = [trigger_inputs] trigger_inputs = [trigger_inputs]
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
@@ -990,12 +1017,9 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__ base_class = self.__class__.__bases__[0].__name__
header = ( header = (
f"{class_name}(\n Class: {base_class}\n" f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
if base_class and base_class != "object"
else f"{class_name}(\n"
) )
if self.trigger_inputs: if self.trigger_inputs:
header += "\n" header += "\n"
header += " " + "=" * 100 + "\n" header += " " + "=" * 100 + "\n"
@@ -1007,14 +1031,14 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
header += " " + "=" * 100 + "\n\n" header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation # Format description with proper indentation
desc_lines = self.description.split('\n') desc_lines = self.description.split("\n")
desc = [] desc = []
# First line with "Description:" label # First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}") desc.append(f" Description: {desc_lines[0]}")
# Subsequent lines with proper indentation # Subsequent lines with proper indentation
if len(desc_lines) > 1: if len(desc_lines) > 1:
desc.extend(f" {line}" for line in desc_lines[1:]) desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n' desc = "\n".join(desc) + "\n"
# Components section - focus only on expected components # Components section - focus only on expected components
expected_components = getattr(self, "expected_components", []) expected_components = getattr(self, "expected_components", [])
@@ -1029,7 +1053,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
for i, (name, block) in enumerate(self.blocks.items()): for i, (name, block) in enumerate(self.blocks.items()):
# Get trigger input for this block # Get trigger input for this block
trigger = None trigger = None
if hasattr(self, 'block_to_trigger_map'): if hasattr(self, "block_to_trigger_map"):
trigger = self.block_to_trigger_map.get(name) trigger = self.block_to_trigger_map.get(name)
# Format the trigger info # Format the trigger info
if trigger is None: if trigger is None:
@@ -1045,21 +1069,13 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# Add block description # Add block description
desc_lines = block.description.split('\n') desc_lines = block.description.split("\n")
indented_desc = desc_lines[0] indented_desc = desc_lines[0]
if len(desc_lines) > 1: if len(desc_lines) > 1:
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n" blocks_str += f" Description: {indented_desc}\n\n"
return ( return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")"
f"{header}\n"
f"{desc}\n\n"
f"{components_str}\n\n"
f"{configs_str}\n\n"
f"{blocks_str}"
f")"
)
@property @property
def doc(self): def doc(self):
@@ -1070,11 +1086,10 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
expected_components=self.expected_components, expected_components=self.expected_components,
expected_configs=self.expected_configs expected_configs=self.expected_configs,
) )
# YiYi TODO: # YiYi TODO:
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader # 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader
@@ -1084,8 +1099,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
Base class for all Modular pipelines loaders. Base class for all Modular pipelines loaders.
""" """
config_name = "modular_model_index.json"
config_name = "modular_model_index.json"
def register_components(self, **kwargs): def register_components(self, **kwargs):
""" """
@@ -1097,7 +1112,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
""" """
for name, module in kwargs.items(): for name, module in kwargs.items():
# current component spec # current component spec
component_spec = self._component_specs.get(name) component_spec = self._component_specs.get(name)
if component_spec is None: if component_spec is None:
@@ -1107,7 +1121,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
is_registered = hasattr(self, name) is_registered = hasattr(self, name)
if module is not None and not hasattr(module, "_diffusers_load_id"): if module is not None and not hasattr(module, "_diffusers_load_id"):
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
# actual library and class name of the module # actual library and class name of the module
@@ -1143,12 +1157,20 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
current_module = getattr(self, name, None) current_module = getattr(self, name, None)
# skip if the component is already registered with the same object # skip if the component is already registered with the same object
if current_module is module: if current_module is module:
logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") logger.info(
f"ModularLoader.register_components: {name} is already registered with same object, skipping"
)
continue continue
# it module is not an instance of the expected type, still register it but with a warning # it module is not an instance of the expected type, still register it but with a warning
if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): if (
logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") module is not None
and component_spec.type_hint is not None
and not isinstance(module, component_spec.type_hint)
):
logger.warning(
f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}"
)
# warn if unregister # warn if unregister
if current_module is not None and module is None: if current_module is not None and module is None:
@@ -1157,10 +1179,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
f"(was {current_module.__class__.__name__})" f"(was {current_module.__class__.__name__})"
) )
# same type, new instance → debug # same type, new instance → debug
elif current_module is not None \ elif (
and module is not None \ current_module is not None
and isinstance(module, current_module.__class__) \ and module is not None
and current_module != module: and isinstance(module, current_module.__class__)
and current_module != module
):
logger.debug( logger.debug(
f"ModularLoader.register_components: replacing existing '{name}' " f"ModularLoader.register_components: replacing existing '{name}' "
f"(same type {type(current_module).__name__}, new instance)" f"(same type {type(current_module).__name__}, new instance)"
@@ -1175,28 +1199,34 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
if module is not None and self._component_manager is not None: if module is not None and self._component_manager is not None:
self._component_manager.add(name, module, self._collection) self._component_manager.add(name, module, self._collection)
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): def __init__(
self,
specs: List[Union[ComponentSpec, ConfigSpec]],
modular_repo: Optional[str] = None,
component_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
**kwargs,
):
""" """
Initialize the loader with a list of component specs and config specs. Initialize the loader with a list of component specs and config specs.
""" """
self._component_manager = component_manager self._component_manager = component_manager
self._collection = collection self._collection = collection
self._component_specs = { self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)}
spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)}
}
self._config_specs = {
spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)
}
# update component_specs and config_specs from modular_repo # update component_specs and config_specs from modular_repo
if modular_repo is not None: if modular_repo is not None:
config_dict = self.load_config(modular_repo, **kwargs) config_dict = self.load_config(modular_repo, **kwargs)
for name, value in config_dict.items(): for name, value in config_dict.items():
if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: if (
name in self._component_specs
and self._component_specs[name].default_creation_method == "from_pretrained"
and isinstance(value, (tuple, list))
and len(value) == 3
):
library, class_name, component_spec_dict = value library, class_name, component_spec_dict = value
component_spec = self._dict_to_component_spec(name, component_spec_dict) component_spec = self._dict_to_component_spec(name, component_spec_dict)
self._component_specs[name] = component_spec self._component_specs[name] = component_spec
@@ -1214,7 +1244,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
default_configs[name] = config_spec.default default_configs[name] = config_spec.default
self.register_to_config(**default_configs) self.register_to_config(**default_configs)
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
r""" r"""
@@ -1280,15 +1309,10 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
return torch.float32 return torch.float32
@property @property
def components(self) -> Dict[str, Any]: def components(self) -> Dict[str, Any]:
# return only components we've actually set as attributes on self # return only components we've actually set as attributes on self
return { return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
name: getattr(self, name)
for name in self._component_specs.keys()
if hasattr(self, name)
}
def update(self, **kwargs): def update(self, **kwargs):
""" """
@@ -1340,24 +1364,20 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
for name, component in passed_components.items(): for name, component in passed_components.items():
if not hasattr(component, "_diffusers_load_id"): if not hasattr(component, "_diffusers_load_id"):
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
if len(kwargs) > 0: if len(kwargs) > 0:
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
self.register_components(**passed_components) self.register_components(**passed_components)
config_to_register = {} config_to_register = {}
for name, new_value in passed_config_values.items(): for name, new_value in passed_config_values.items():
# e.g. requires_aesthetics_score = False # e.g. requires_aesthetics_score = False
self._config_specs[name].default = new_value self._config_specs[name].default = new_value
config_to_register[name] = new_value config_to_register[name] = new_value
self.register_to_config(**config_to_register) self.register_to_config(**config_to_register)
# YiYi TODO: support map for additional from_pretrained kwargs # YiYi TODO: support map for additional from_pretrained kwargs
def load(self, component_names: Optional[List[str]] = None, **kwargs): def load(self, component_names: Optional[List[str]] = None, **kwargs):
""" """
@@ -1410,8 +1430,9 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
# YiYi TODO: # YiYi TODO:
# 1. should support save some components too! currently only modular_model_index.json is saved # 1. should support save some components too! currently only modular_model_index.json is saved
# 2. maybe order the json file to make it more readable: configs first, then components # 2. maybe order the json file to make it more readable: configs first, then components
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): def save_pretrained(
self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs
):
component_names = list(self._component_specs.keys()) component_names = list(self._component_specs.keys())
config_names = list(self._config_specs.keys()) config_names = list(self._config_specs.keys())
self.register_to_config(_components_names=component_names, _configs_names=config_names) self.register_to_config(_components_names=component_names, _configs_names=config_names)
@@ -1421,11 +1442,11 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
config.pop("_configs_names", None) config.pop("_configs_names", None)
self._internal_dict = FrozenDict(config) self._internal_dict = FrozenDict(config)
@classmethod @classmethod
@validate_hf_hub_args @validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): def from_pretrained(
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs
):
config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
expected_component = set(config_dict.pop("_components_names")) expected_component = set(config_dict.pop("_components_names"))
expected_config = set(config_dict.pop("_configs_names")) expected_config = set(config_dict.pop("_configs_names"))
@@ -1450,7 +1471,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) component_specs.append(ComponentSpec(name=name, default_creation_method="from_config"))
return cls(component_specs + config_specs) return cls(component_specs + config_specs)
@staticmethod @staticmethod
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
""" """
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
import inspect import inspect
from dataclasses import dataclass, asdict, field, fields import re
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils.import_utils import is_torch_available from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
if is_torch_available(): if is_torch_available():
import torch import torch
@@ -176,7 +177,7 @@ class ComponentSpec:
if self.type_hint is None or not isinstance(self.type_hint, type): if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError( raise ValueError(
f"`type_hint` is required when using from_config creation method." "`type_hint` is required when using from_config creation method."
) )
config = config or self.config or {} config = config or self.config or {}
@@ -209,7 +210,7 @@ class ComponentSpec:
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None) repo = load_kwargs.pop("repo", None)
if repo is None: if repo is None:
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") raise ValueError("`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None: if self.type_hint is None:
try: try:
@@ -334,6 +334,7 @@ def maybe_raise_or_warn(
# a simpler version of get_class_obj_and_candidates, it won't work with custom code # a simpler version of get_class_obj_and_candidates, it won't work with custom code
def simple_get_class_obj(library_name, class_name): def simple_get_class_obj(library_name, class_name):
from diffusers import pipelines from diffusers import pipelines
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
if is_pipeline_module: if is_pipeline_module:
@@ -345,6 +346,7 @@ def simple_get_class_obj(library_name, class_name):
return class_obj return class_obj
def get_class_obj_and_candidates( def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
): ):
@@ -61,6 +61,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
from .pipeline_stable_diffusion_xl_modular import ( from .pipeline_stable_diffusion_xl_modular import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDecodeLatentsStep,
StableDiffusionXLDenoiseStep, StableDiffusionXLDenoiseStep,
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep, StableDiffusionXLSetTimestepsStep,
StableDiffusionXLTextEncoderStep, StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoPipeline,
) )
try: try:
@@ -13,17 +13,28 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Any, List, Optional, Tuple, Union, Dict from collections import OrderedDict
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import PIL import PIL
import torch import torch
from collections import OrderedDict from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...image_processor import VaeImageProcessor, PipelineImageInput from ...configuration_utils import FrozenDict
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...guiders import ClassifierFreeGuidance
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import EulerDiscreteScheduler
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
logging, logging,
@@ -34,33 +45,20 @@ from ...utils.torch_utils import randn_tensor, unwrap_module
from ..controlnet.multicontrolnet import MultiControlNetModel from ..controlnet.multicontrolnet import MultiControlNetModel
from ..modular_pipeline import ( from ..modular_pipeline import (
AutoPipelineBlocks, AutoPipelineBlocks,
ModularLoader,
PipelineBlock,
PipelineState,
InputParam,
OutputParam,
SequentialPipelineBlocks,
ComponentSpec, ComponentSpec,
ConfigSpec, ConfigSpec,
InputParam,
ModularLoader,
OutputParam,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
) )
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import ( from .pipeline_output import (
StableDiffusionXLPipelineOutput, StableDiffusionXLPipelineOutput,
) )
from transformers import (
CLIPTextModel,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...schedulers import EulerDiscreteScheduler
from ...guiders import ClassifierFreeGuidance
from ...configuration_utils import FrozenDict
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+148 -6
View File
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""Utilities to dynamically load objects from the Hub.""" """Utilities to dynamically load objects from the Hub."""
import hashlib
import importlib import importlib
import inspect import inspect
import json import json
@@ -21,8 +22,9 @@ import os
import re import re
import shutil import shutil
import sys import sys
import threading
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, ModuleType, Optional, Union
from urllib import request from urllib import request
from huggingface_hub import hf_hub_download, model_info from huggingface_hub import hf_hub_download, model_info
@@ -37,6 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions(): def get_diffusers_versions():
@@ -154,15 +157,132 @@ def check_imports(filename):
return get_relative_imports(filename) return get_relative_imports(filename)
def get_class_in_module(class_name, module_path): def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
def get_class_in_modular_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
) -> type:
"""
Import a module on the cache directory for modules and extract a class from it.
Args:
class_name (`str`): The name of the class to import.
module_path (`str` or `os.PathLike`): The path to the module to import.
force_reload (`bool`, *optional*, defaults to `False`):
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
Otherwise, the module is only reloaded if the file has changed.
Returns:
`typing.Type`: The class looked for.
"""
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
# Hash the module file and all its relative imports to check if we need to reload it
module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
# reload in both cases, unless the module is already imported and the hash hits
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash
return getattr(module, class_name)
def get_class_in_module(class_name, module_path, force_reload=False):
""" """
Import a module on the cache directory for modules and extract a class from it. Import a module on the cache directory for modules and extract a class from it.
""" """
module_path = module_path.replace(os.path.sep, ".") name = os.path.normpath(module_path)
module = importlib.import_module(module_path) if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
module_spec.loader.exec_module(module)
if class_name is None: if class_name is None:
return find_pipeline_class(module) return find_pipeline_class(module)
return getattr(module, class_name) return getattr(module, class_name)
@@ -203,6 +323,7 @@ def get_cached_module_file(
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
is_modular: bool = False,
): ):
""" """
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -257,7 +378,7 @@ def get_cached_module_file(
if os.path.isfile(module_file_or_url): if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url resolved_module_file = module_file_or_url
submodule = "local" submodule = "local"
elif pretrained_model_name_or_path.count("/") == 0: elif pretrained_model_name_or_path.count("/") == 0 and not is_modular:
available_versions = get_diffusers_versions() available_versions = get_diffusers_versions()
# cut ".dev0" # cut ".dev0"
latest_version = "v" + ".".join(__version__.split(".")[:3]) latest_version = "v" + ".".join(__version__.split(".")[:3])
@@ -297,6 +418,24 @@ def get_cached_module_file(
except EnvironmentError: except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise raise
elif is_modular:
try:
# Load from URL or cache if already cached
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
else: else:
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
@@ -381,6 +520,7 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
is_modular: bool = False,
**kwargs, **kwargs,
): ):
""" """
@@ -453,5 +593,7 @@ def get_class_from_dynamic_module(
token=token, token=token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
is_modular=is_modular,
) )
return get_class_in_module(class_name, final_module.replace(".py", "")) __import__("ipdb").set_trace()
return get_class_in_module(class_name, final_module)