Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c8a7617536 |
@@ -761,8 +761,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
LayerSkipConfig,
|
LayerSkipConfig,
|
||||||
PyramidAttentionBroadcastConfig,
|
PyramidAttentionBroadcastConfig,
|
||||||
SmoothedEnergyGuidanceConfig,
|
SmoothedEnergyGuidanceConfig,
|
||||||
apply_layer_skip,
|
|
||||||
apply_faster_cache,
|
apply_faster_cache,
|
||||||
|
apply_layer_skip,
|
||||||
apply_pyramid_attention_broadcast,
|
apply_pyramid_attention_broadcast,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -1085,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionSAGPipeline,
|
StableDiffusionSAGPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
StableDiffusionXLAdapterPipeline,
|
StableDiffusionXLAdapterPipeline,
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||||
StableDiffusionXLControlNetInpaintPipeline,
|
StableDiffusionXLControlNetInpaintPipeline,
|
||||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||||
@@ -1102,7 +1103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLPAGInpaintPipeline,
|
StableDiffusionXLPAGInpaintPipeline,
|
||||||
StableDiffusionXLPAGPipeline,
|
StableDiffusionXLPAGPipeline,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
StableUnCLIPImg2ImgPipeline,
|
StableUnCLIPImg2ImgPipeline,
|
||||||
StableUnCLIPPipeline,
|
StableUnCLIPPipeline,
|
||||||
StableVideoDiffusionPipeline,
|
StableVideoDiffusionPipeline,
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
|
|||||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
|
|||||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
|
|||||||
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,12 @@ import torch
|
|||||||
|
|
||||||
from ..utils import get_logger
|
from ..utils import get_logger
|
||||||
from ..utils.torch_utils import unwrap_module
|
from ..utils.torch_utils import unwrap_module
|
||||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
|
from ._common import (
|
||||||
|
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||||
|
_ATTENTION_CLASSES,
|
||||||
|
_FEEDFORWARD_CLASSES,
|
||||||
|
_get_submodule_from_fqn,
|
||||||
|
)
|
||||||
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
||||||
from .hooks import HookRegistry, ModelHook
|
from .hooks import HookRegistry, ModelHook
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|||||||
@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .ip_adapter import (
|
from .ip_adapter import (
|
||||||
FluxIPAdapterMixin,
|
FluxIPAdapterMixin,
|
||||||
IPAdapterMixin,
|
IPAdapterMixin,
|
||||||
SD3IPAdapterMixin,
|
|
||||||
ModularIPAdapterMixin,
|
ModularIPAdapterMixin,
|
||||||
|
SD3IPAdapterMixin,
|
||||||
)
|
)
|
||||||
from .lora_pipeline import (
|
from .lora_pipeline import (
|
||||||
AmusedLoraLoaderMixin,
|
AmusedLoraLoaderMixin,
|
||||||
|
|||||||
@@ -703,12 +703,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||||
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||||
from .stable_diffusion_xl import (
|
from .stable_diffusion_xl import (
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
StableDiffusionXLImg2ImgPipeline,
|
StableDiffusionXLImg2ImgPipeline,
|
||||||
StableDiffusionXLInpaintPipeline,
|
StableDiffusionXLInpaintPipeline,
|
||||||
StableDiffusionXLInstructPix2PixPipeline,
|
StableDiffusionXLInstructPix2PixPipeline,
|
||||||
StableDiffusionXLModularLoader,
|
StableDiffusionXLModularLoader,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
)
|
)
|
||||||
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
||||||
from .t2i_adapter import (
|
from .t2i_adapter import (
|
||||||
|
|||||||
@@ -12,21 +12,18 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
from typing import List, Optional, Union, Dict, Any
|
from typing import Any, Dict, List, Optional, Union
|
||||||
import copy
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from ..models.modeling_utils import ModelMixin
|
|
||||||
from .modular_pipeline_utils import ComponentSpec
|
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
@@ -231,8 +228,9 @@ class AutoOffloadStrategy:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
from .modular_pipeline_utils import ComponentSpec
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
class ComponentsManager:
|
class ComponentsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.components = OrderedDict()
|
self.components = OrderedDict()
|
||||||
@@ -726,7 +724,7 @@ class ComponentsManager:
|
|||||||
if info.get("adapters") is not None:
|
if info.get("adapters") is not None:
|
||||||
output += f" Adapters: {info['adapters']}\n"
|
output += f" Adapters: {info['adapters']}\n"
|
||||||
if info.get("ip_adapter"):
|
if info.get("ip_adapter"):
|
||||||
output += f" IP-Adapter: Enabled\n"
|
output += " IP-Adapter: Enabled\n"
|
||||||
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
|
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -12,29 +12,27 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Tuple, Union, Optional, Type
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
from huggingface_hub.utils import validate_hf_hub_args
|
from huggingface_hub.utils import validate_hf_hub_args
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
is_accelerate_available,
|
|
||||||
is_accelerate_version,
|
|
||||||
logging,
|
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
is_accelerate_available,
|
||||||
|
logging,
|
||||||
)
|
)
|
||||||
from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple
|
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||||
|
from .components_manager import ComponentsManager
|
||||||
from .modular_pipeline_utils import (
|
from .modular_pipeline_utils import (
|
||||||
ComponentSpec,
|
ComponentSpec,
|
||||||
ConfigSpec,
|
ConfigSpec,
|
||||||
@@ -42,18 +40,15 @@ from .modular_pipeline_utils import (
|
|||||||
OutputParam,
|
OutputParam,
|
||||||
format_components,
|
format_components,
|
||||||
format_configs,
|
format_configs,
|
||||||
format_input_params,
|
|
||||||
format_inputs_short,
|
format_inputs_short,
|
||||||
format_intermediates_short,
|
format_intermediates_short,
|
||||||
format_output_params,
|
|
||||||
format_params,
|
|
||||||
make_doc_string,
|
make_doc_string,
|
||||||
)
|
)
|
||||||
from .components_manager import ComponentsManager
|
from .pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
|
||||||
|
|
||||||
|
|
||||||
from copy import deepcopy
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
import accelerate
|
pass
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -108,10 +103,7 @@ class PipelineState:
|
|||||||
intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
|
intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"PipelineState(\n"
|
f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" f" intermediates={{\n{intermediates}\n }}\n" f")"
|
||||||
f" inputs={{\n{inputs}\n }},\n"
|
|
||||||
f" intermediates={{\n{intermediates}\n }}\n"
|
|
||||||
f")"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -120,6 +112,7 @@ class BlockState:
|
|||||||
"""
|
"""
|
||||||
Container for block state data with attribute access and formatted representation.
|
Container for block state data with attribute access and formatted representation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
@@ -158,20 +151,66 @@ class BlockState:
|
|||||||
return f"BlockState(\n{attributes}\n)"
|
return f"BlockState(\n{attributes}\n)"
|
||||||
|
|
||||||
|
|
||||||
|
class ModularPipelineMixin(ConfigMixin):
|
||||||
class ModularPipelineMixin:
|
|
||||||
"""
|
"""
|
||||||
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
|
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config_name = "config.json"
|
||||||
|
|
||||||
def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
pretrained_model_name_or_path: str,
|
||||||
|
trust_remote_code: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
hub_kwargs_names = [
|
||||||
|
"cache_dir",
|
||||||
|
"force_download",
|
||||||
|
"local_files_only",
|
||||||
|
"proxies",
|
||||||
|
"resume_download",
|
||||||
|
"revision",
|
||||||
|
"subfolder",
|
||||||
|
"token",
|
||||||
|
]
|
||||||
|
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||||
|
|
||||||
|
config = cls.load_config(pretrained_model_name_or_path)
|
||||||
|
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||||
|
trust_remote_code = resolve_trust_remote_code(
|
||||||
|
trust_remote_code, pretrained_model_name_or_path, False, has_remote_code
|
||||||
|
)
|
||||||
|
if not (has_remote_code and trust_remote_code):
|
||||||
|
raise ValueError("")
|
||||||
|
|
||||||
|
class_ref = config["auto_map"][cls.__name__]
|
||||||
|
module_file, class_name = class_ref.split(".")
|
||||||
|
module_file = module_file + ".py"
|
||||||
|
block_cls = get_class_from_dynamic_module(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
module_file=module_file,
|
||||||
|
class_name=class_name,
|
||||||
|
is_modular=True,
|
||||||
|
**hub_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return block_cls()
|
||||||
|
|
||||||
|
def setup_loader(
|
||||||
|
self,
|
||||||
|
modular_repo: Optional[Union[str, os.PathLike]] = None,
|
||||||
|
component_manager: Optional[ComponentsManager] = None,
|
||||||
|
collection: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
create a mouldar loader, optionally accept modular_repo to load from hub.
|
create a ModularLoader, optionally accept modular_repo to load from hub.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Import components loader (it is model-specific class)
|
# Import components loader (it is model-specific class)
|
||||||
loader_class_name = MODULAR_LOADER_MAPPING[self.model_name]
|
loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
|
||||||
|
|
||||||
diffusers_module = importlib.import_module("diffusers")
|
diffusers_module = importlib.import_module("diffusers")
|
||||||
loader_class = getattr(diffusers_module, loader_class_name)
|
loader_class = getattr(diffusers_module, loader_class_name)
|
||||||
|
|
||||||
@@ -181,8 +220,9 @@ class ModularPipelineMixin:
|
|||||||
# Create the loader with the updated specs
|
# Create the loader with the updated specs
|
||||||
specs = component_specs + config_specs
|
specs = component_specs + config_specs
|
||||||
|
|
||||||
self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
|
self.loader = loader_class(
|
||||||
|
specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_call_parameters(self) -> Dict[str, Any]:
|
def default_call_parameters(self) -> Dict[str, Any]:
|
||||||
@@ -238,7 +278,6 @@ class ModularPipelineMixin:
|
|||||||
if output is None:
|
if output is None:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
elif isinstance(output, str):
|
elif isinstance(output, str):
|
||||||
return state.get_intermediate(output)
|
return state.get_intermediate(output)
|
||||||
|
|
||||||
@@ -268,7 +307,6 @@ class ModularPipelineMixin:
|
|||||||
|
|
||||||
|
|
||||||
class PipelineBlock(ModularPipelineMixin):
|
class PipelineBlock(ModularPipelineMixin):
|
||||||
|
|
||||||
model_name = None
|
model_name = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -284,7 +322,6 @@ class PipelineBlock(ModularPipelineMixin):
|
|||||||
def expected_configs(self) -> List[ConfigSpec]:
|
def expected_configs(self) -> List[ConfigSpec]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
|
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> List[InputParam]:
|
def inputs(self) -> List[InputParam]:
|
||||||
@@ -322,7 +359,6 @@ class PipelineBlock(ModularPipelineMixin):
|
|||||||
input_names.append(input_param.name)
|
input_names.append(input_param.name)
|
||||||
return input_names
|
return input_names
|
||||||
|
|
||||||
|
|
||||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||||
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
||||||
|
|
||||||
@@ -331,14 +367,14 @@ class PipelineBlock(ModularPipelineMixin):
|
|||||||
base_class = self.__class__.__bases__[0].__name__
|
base_class = self.__class__.__bases__[0].__name__
|
||||||
|
|
||||||
# Format description with proper indentation
|
# Format description with proper indentation
|
||||||
desc_lines = self.description.split('\n')
|
desc_lines = self.description.split("\n")
|
||||||
desc = []
|
desc = []
|
||||||
# First line with "Description:" label
|
# First line with "Description:" label
|
||||||
desc.append(f" Description: {desc_lines[0]}")
|
desc.append(f" Description: {desc_lines[0]}")
|
||||||
# Subsequent lines with proper indentation
|
# Subsequent lines with proper indentation
|
||||||
if len(desc_lines) > 1:
|
if len(desc_lines) > 1:
|
||||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||||
desc = '\n'.join(desc) + '\n'
|
desc = "\n".join(desc) + "\n"
|
||||||
|
|
||||||
# Components section - use format_components with add_empty_lines=False
|
# Components section - use format_components with add_empty_lines=False
|
||||||
expected_components = getattr(self, "expected_components", [])
|
expected_components = getattr(self, "expected_components", [])
|
||||||
@@ -355,7 +391,9 @@ class PipelineBlock(ModularPipelineMixin):
|
|||||||
inputs = "Inputs:\n " + inputs_str
|
inputs = "Inputs:\n " + inputs_str
|
||||||
|
|
||||||
# Intermediates section
|
# Intermediates section
|
||||||
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
|
intermediates_str = format_intermediates_short(
|
||||||
|
self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs
|
||||||
|
)
|
||||||
intermediates = f"Intermediates:\n{intermediates_str}"
|
intermediates = f"Intermediates:\n{intermediates_str}"
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -369,7 +407,6 @@ class PipelineBlock(ModularPipelineMixin):
|
|||||||
f")"
|
f")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def doc(self):
|
def doc(self):
|
||||||
return make_doc_string(
|
return make_doc_string(
|
||||||
@@ -379,10 +416,9 @@ class PipelineBlock(ModularPipelineMixin):
|
|||||||
self.description,
|
self.description,
|
||||||
class_name=self.__class__.__name__,
|
class_name=self.__class__.__name__,
|
||||||
expected_components=self.expected_components,
|
expected_components=self.expected_components,
|
||||||
expected_configs=self.expected_configs
|
expected_configs=self.expected_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_block_state(self, state: PipelineState) -> dict:
|
def get_block_state(self, state: PipelineState) -> dict:
|
||||||
"""Get all inputs and intermediates in one dictionary"""
|
"""Get all inputs and intermediates in one dictionary"""
|
||||||
data = {}
|
data = {}
|
||||||
@@ -429,9 +465,11 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
|
|||||||
for input_param in inputs:
|
for input_param in inputs:
|
||||||
if input_param.name in combined_dict:
|
if input_param.name in combined_dict:
|
||||||
current_param = combined_dict[input_param.name]
|
current_param = combined_dict[input_param.name]
|
||||||
if (current_param.default is not None and
|
if (
|
||||||
input_param.default is not None and
|
current_param.default is not None
|
||||||
current_param.default != input_param.default):
|
and input_param.default is not None
|
||||||
|
and current_param.default != input_param.default
|
||||||
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Multiple different default values found for input '{input_param.name}': "
|
f"Multiple different default values found for input '{input_param.name}': "
|
||||||
f"{current_param.default} (from block '{value_sources[input_param.name]}') and "
|
f"{current_param.default} (from block '{value_sources[input_param.name]}') and "
|
||||||
@@ -446,6 +484,7 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
|
|||||||
|
|
||||||
return list(combined_dict.values())
|
return list(combined_dict.values())
|
||||||
|
|
||||||
|
|
||||||
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
|
def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
|
||||||
"""
|
"""
|
||||||
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs,
|
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs,
|
||||||
@@ -487,15 +526,15 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
blocks[block_name] = block_cls()
|
blocks[block_name] = block_cls()
|
||||||
self.blocks = blocks
|
self.blocks = blocks
|
||||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||||
raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.")
|
raise ValueError(
|
||||||
|
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||||
|
)
|
||||||
default_blocks = [t for t in self.block_trigger_inputs if t is None]
|
default_blocks = [t for t in self.block_trigger_inputs if t is None]
|
||||||
# can only have 1 or 0 default block, and has to put in the last
|
# can only have 1 or 0 default block, and has to put in the last
|
||||||
# the order of blocksmatters here because the first block with matching trigger will be dispatched
|
# the order of blocksmatters here because the first block with matching trigger will be dispatched
|
||||||
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
|
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
|
||||||
# if both mask and image are provided, it is inpaint; if only image is provided, it is img2img
|
# if both mask and image are provided, it is inpaint; if only image is provided, it is img2img
|
||||||
if len(default_blocks) > 1 or (
|
if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
|
||||||
len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
|
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
|
||||||
"in block_trigger_inputs."
|
"in block_trigger_inputs."
|
||||||
@@ -532,7 +571,6 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
expected_configs.append(config)
|
expected_configs.append(config)
|
||||||
return expected_configs
|
return expected_configs
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_inputs(self) -> List[str]:
|
def required_inputs(self) -> List[str]:
|
||||||
first_block = next(iter(self.blocks.values()))
|
first_block = next(iter(self.blocks.values()))
|
||||||
@@ -557,7 +595,6 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
|
|
||||||
return list(required_by_all)
|
return list(required_by_all)
|
||||||
|
|
||||||
|
|
||||||
# YiYi TODO: add test for this
|
# YiYi TODO: add test for this
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> List[Tuple[str, Any]]:
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
@@ -571,7 +608,6 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
input_param.required = False
|
input_param.required = False
|
||||||
return combined_inputs
|
return combined_inputs
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def intermediates_inputs(self) -> List[str]:
|
def intermediates_inputs(self) -> List[str]:
|
||||||
named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()]
|
named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()]
|
||||||
@@ -630,18 +666,19 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
Returns a set of all unique trigger input values found in the blocks.
|
Returns a set of all unique trigger input values found in the blocks.
|
||||||
Returns: Set[str] containing all unique block_trigger_inputs values
|
Returns: Set[str] containing all unique block_trigger_inputs values
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fn_recursive_get_trigger(blocks):
|
def fn_recursive_get_trigger(blocks):
|
||||||
trigger_values = set()
|
trigger_values = set()
|
||||||
|
|
||||||
if blocks is not None:
|
if blocks is not None:
|
||||||
for name, block in blocks.items():
|
for name, block in blocks.items():
|
||||||
# Check if current block has trigger inputs(i.e. auto block)
|
# Check if current block has trigger inputs(i.e. auto block)
|
||||||
if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None:
|
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||||
# Add all non-None values from the trigger inputs list
|
# Add all non-None values from the trigger inputs list
|
||||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||||
|
|
||||||
# If block has blocks, recursively check them
|
# If block has blocks, recursively check them
|
||||||
if hasattr(block, 'blocks'):
|
if hasattr(block, "blocks"):
|
||||||
nested_triggers = fn_recursive_get_trigger(block.blocks)
|
nested_triggers = fn_recursive_get_trigger(block.blocks)
|
||||||
trigger_values.update(nested_triggers)
|
trigger_values.update(nested_triggers)
|
||||||
|
|
||||||
@@ -660,12 +697,9 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
class_name = self.__class__.__name__
|
class_name = self.__class__.__name__
|
||||||
base_class = self.__class__.__bases__[0].__name__
|
base_class = self.__class__.__bases__[0].__name__
|
||||||
header = (
|
header = (
|
||||||
f"{class_name}(\n Class: {base_class}\n"
|
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
|
||||||
if base_class and base_class != "object"
|
|
||||||
else f"{class_name}(\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if self.trigger_inputs:
|
if self.trigger_inputs:
|
||||||
header += "\n"
|
header += "\n"
|
||||||
header += " " + "=" * 100 + "\n"
|
header += " " + "=" * 100 + "\n"
|
||||||
@@ -677,14 +711,14 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
header += " " + "=" * 100 + "\n\n"
|
header += " " + "=" * 100 + "\n\n"
|
||||||
|
|
||||||
# Format description with proper indentation
|
# Format description with proper indentation
|
||||||
desc_lines = self.description.split('\n')
|
desc_lines = self.description.split("\n")
|
||||||
desc = []
|
desc = []
|
||||||
# First line with "Description:" label
|
# First line with "Description:" label
|
||||||
desc.append(f" Description: {desc_lines[0]}")
|
desc.append(f" Description: {desc_lines[0]}")
|
||||||
# Subsequent lines with proper indentation
|
# Subsequent lines with proper indentation
|
||||||
if len(desc_lines) > 1:
|
if len(desc_lines) > 1:
|
||||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||||
desc = '\n'.join(desc) + '\n'
|
desc = "\n".join(desc) + "\n"
|
||||||
|
|
||||||
# Components section - focus only on expected components
|
# Components section - focus only on expected components
|
||||||
expected_components = getattr(self, "expected_components", [])
|
expected_components = getattr(self, "expected_components", [])
|
||||||
@@ -699,7 +733,7 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
for i, (name, block) in enumerate(self.blocks.items()):
|
for i, (name, block) in enumerate(self.blocks.items()):
|
||||||
# Get trigger input for this block
|
# Get trigger input for this block
|
||||||
trigger = None
|
trigger = None
|
||||||
if hasattr(self, 'block_to_trigger_map'):
|
if hasattr(self, "block_to_trigger_map"):
|
||||||
trigger = self.block_to_trigger_map.get(name)
|
trigger = self.block_to_trigger_map.get(name)
|
||||||
# Format the trigger info
|
# Format the trigger info
|
||||||
if trigger is None:
|
if trigger is None:
|
||||||
@@ -715,21 +749,13 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||||
|
|
||||||
# Add block description
|
# Add block description
|
||||||
desc_lines = block.description.split('\n')
|
desc_lines = block.description.split("\n")
|
||||||
indented_desc = desc_lines[0]
|
indented_desc = desc_lines[0]
|
||||||
if len(desc_lines) > 1:
|
if len(desc_lines) > 1:
|
||||||
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
|
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
|
||||||
blocks_str += f" Description: {indented_desc}\n\n"
|
blocks_str += f" Description: {indented_desc}\n\n"
|
||||||
|
|
||||||
return (
|
return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")"
|
||||||
f"{header}\n"
|
|
||||||
f"{desc}\n\n"
|
|
||||||
f"{components_str}\n\n"
|
|
||||||
f"{configs_str}\n\n"
|
|
||||||
f"{blocks_str}"
|
|
||||||
f")"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def doc(self):
|
def doc(self):
|
||||||
@@ -740,13 +766,15 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
|||||||
self.description,
|
self.description,
|
||||||
class_name=self.__class__.__name__,
|
class_name=self.__class__.__name__,
|
||||||
expected_components=self.expected_components,
|
expected_components=self.expected_components,
|
||||||
expected_configs=self.expected_configs
|
expected_configs=self.expected_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SequentialPipelineBlocks(ModularPipelineMixin):
|
class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||||
"""
|
"""
|
||||||
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
|
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
block_classes = []
|
block_classes = []
|
||||||
block_names = []
|
block_names = []
|
||||||
|
|
||||||
@@ -798,7 +826,6 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
blocks[block_name] = block_cls()
|
blocks[block_name] = block_cls()
|
||||||
self.blocks = blocks
|
self.blocks = blocks
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_inputs(self) -> List[str]:
|
def required_inputs(self) -> List[str]:
|
||||||
# Get the first block from the dictionary
|
# Get the first block from the dictionary
|
||||||
@@ -884,18 +911,19 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
Returns a set of all unique trigger input values found in the blocks.
|
Returns a set of all unique trigger input values found in the blocks.
|
||||||
Returns: Set[str] containing all unique block_trigger_inputs values
|
Returns: Set[str] containing all unique block_trigger_inputs values
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fn_recursive_get_trigger(blocks):
|
def fn_recursive_get_trigger(blocks):
|
||||||
trigger_values = set()
|
trigger_values = set()
|
||||||
|
|
||||||
if blocks is not None:
|
if blocks is not None:
|
||||||
for name, block in blocks.items():
|
for name, block in blocks.items():
|
||||||
# Check if current block has trigger inputs(i.e. auto block)
|
# Check if current block has trigger inputs(i.e. auto block)
|
||||||
if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None:
|
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||||
# Add all non-None values from the trigger inputs list
|
# Add all non-None values from the trigger inputs list
|
||||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||||
|
|
||||||
# If block has blocks, recursively check them
|
# If block has blocks, recursively check them
|
||||||
if hasattr(block, 'blocks'):
|
if hasattr(block, "blocks"):
|
||||||
nested_triggers = fn_recursive_get_trigger(block.blocks)
|
nested_triggers = fn_recursive_get_trigger(block.blocks)
|
||||||
trigger_values.update(nested_triggers)
|
trigger_values.update(nested_triggers)
|
||||||
|
|
||||||
@@ -915,8 +943,8 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
result_blocks = OrderedDict()
|
result_blocks = OrderedDict()
|
||||||
|
|
||||||
# sequential or PipelineBlock
|
# sequential or PipelineBlock
|
||||||
if not hasattr(block, 'block_trigger_inputs'):
|
if not hasattr(block, "block_trigger_inputs"):
|
||||||
if hasattr(block, 'blocks'):
|
if hasattr(block, "blocks"):
|
||||||
# sequential
|
# sequential
|
||||||
for block_name, block in block.blocks.items():
|
for block_name, block in block.blocks.items():
|
||||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
|
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
|
||||||
@@ -925,7 +953,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
# PipelineBlock
|
# PipelineBlock
|
||||||
result_blocks[block_name] = block
|
result_blocks[block_name] = block
|
||||||
# Add this block's output names to active triggers if defined
|
# Add this block's output names to active triggers if defined
|
||||||
if hasattr(block, 'outputs'):
|
if hasattr(block, "outputs"):
|
||||||
active_triggers.update(out.name for out in block.outputs)
|
active_triggers.update(out.name for out in block.outputs)
|
||||||
return result_blocks
|
return result_blocks
|
||||||
|
|
||||||
@@ -947,13 +975,13 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
|
|
||||||
if this_block is not None:
|
if this_block is not None:
|
||||||
# sequential/auto
|
# sequential/auto
|
||||||
if hasattr(this_block, 'blocks'):
|
if hasattr(this_block, "blocks"):
|
||||||
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
|
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
|
||||||
else:
|
else:
|
||||||
# PipelineBlock
|
# PipelineBlock
|
||||||
result_blocks[block_name] = this_block
|
result_blocks[block_name] = this_block
|
||||||
# Add this block's output names to active triggers if defined
|
# Add this block's output names to active triggers if defined
|
||||||
if hasattr(this_block, 'outputs'):
|
if hasattr(this_block, "outputs"):
|
||||||
active_triggers.update(out.name for out in this_block.outputs)
|
active_triggers.update(out.name for out in this_block.outputs)
|
||||||
|
|
||||||
return result_blocks
|
return result_blocks
|
||||||
@@ -968,7 +996,6 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
trigger_inputs_all = self.trigger_inputs
|
trigger_inputs_all = self.trigger_inputs
|
||||||
|
|
||||||
if trigger_inputs is not None:
|
if trigger_inputs is not None:
|
||||||
|
|
||||||
if not isinstance(trigger_inputs, (list, tuple, set)):
|
if not isinstance(trigger_inputs, (list, tuple, set)):
|
||||||
trigger_inputs = [trigger_inputs]
|
trigger_inputs = [trigger_inputs]
|
||||||
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
|
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
|
||||||
@@ -990,12 +1017,9 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
class_name = self.__class__.__name__
|
class_name = self.__class__.__name__
|
||||||
base_class = self.__class__.__bases__[0].__name__
|
base_class = self.__class__.__bases__[0].__name__
|
||||||
header = (
|
header = (
|
||||||
f"{class_name}(\n Class: {base_class}\n"
|
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
|
||||||
if base_class and base_class != "object"
|
|
||||||
else f"{class_name}(\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if self.trigger_inputs:
|
if self.trigger_inputs:
|
||||||
header += "\n"
|
header += "\n"
|
||||||
header += " " + "=" * 100 + "\n"
|
header += " " + "=" * 100 + "\n"
|
||||||
@@ -1007,14 +1031,14 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
header += " " + "=" * 100 + "\n\n"
|
header += " " + "=" * 100 + "\n\n"
|
||||||
|
|
||||||
# Format description with proper indentation
|
# Format description with proper indentation
|
||||||
desc_lines = self.description.split('\n')
|
desc_lines = self.description.split("\n")
|
||||||
desc = []
|
desc = []
|
||||||
# First line with "Description:" label
|
# First line with "Description:" label
|
||||||
desc.append(f" Description: {desc_lines[0]}")
|
desc.append(f" Description: {desc_lines[0]}")
|
||||||
# Subsequent lines with proper indentation
|
# Subsequent lines with proper indentation
|
||||||
if len(desc_lines) > 1:
|
if len(desc_lines) > 1:
|
||||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||||
desc = '\n'.join(desc) + '\n'
|
desc = "\n".join(desc) + "\n"
|
||||||
|
|
||||||
# Components section - focus only on expected components
|
# Components section - focus only on expected components
|
||||||
expected_components = getattr(self, "expected_components", [])
|
expected_components = getattr(self, "expected_components", [])
|
||||||
@@ -1029,7 +1053,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
for i, (name, block) in enumerate(self.blocks.items()):
|
for i, (name, block) in enumerate(self.blocks.items()):
|
||||||
# Get trigger input for this block
|
# Get trigger input for this block
|
||||||
trigger = None
|
trigger = None
|
||||||
if hasattr(self, 'block_to_trigger_map'):
|
if hasattr(self, "block_to_trigger_map"):
|
||||||
trigger = self.block_to_trigger_map.get(name)
|
trigger = self.block_to_trigger_map.get(name)
|
||||||
# Format the trigger info
|
# Format the trigger info
|
||||||
if trigger is None:
|
if trigger is None:
|
||||||
@@ -1045,21 +1069,13 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||||
|
|
||||||
# Add block description
|
# Add block description
|
||||||
desc_lines = block.description.split('\n')
|
desc_lines = block.description.split("\n")
|
||||||
indented_desc = desc_lines[0]
|
indented_desc = desc_lines[0]
|
||||||
if len(desc_lines) > 1:
|
if len(desc_lines) > 1:
|
||||||
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
|
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
|
||||||
blocks_str += f" Description: {indented_desc}\n\n"
|
blocks_str += f" Description: {indented_desc}\n\n"
|
||||||
|
|
||||||
return (
|
return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")"
|
||||||
f"{header}\n"
|
|
||||||
f"{desc}\n\n"
|
|
||||||
f"{components_str}\n\n"
|
|
||||||
f"{configs_str}\n\n"
|
|
||||||
f"{blocks_str}"
|
|
||||||
f")"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def doc(self):
|
def doc(self):
|
||||||
@@ -1070,11 +1086,10 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
|||||||
self.description,
|
self.description,
|
||||||
class_name=self.__class__.__name__,
|
class_name=self.__class__.__name__,
|
||||||
expected_components=self.expected_components,
|
expected_components=self.expected_components,
|
||||||
expected_configs=self.expected_configs
|
expected_configs=self.expected_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# YiYi TODO:
|
# YiYi TODO:
|
||||||
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
|
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
|
||||||
# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader
|
# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader
|
||||||
@@ -1084,8 +1099,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
Base class for all Modular pipelines loaders.
|
Base class for all Modular pipelines loaders.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_name = "modular_model_index.json"
|
|
||||||
|
|
||||||
|
config_name = "modular_model_index.json"
|
||||||
|
|
||||||
def register_components(self, **kwargs):
|
def register_components(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -1097,7 +1112,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
for name, module in kwargs.items():
|
for name, module in kwargs.items():
|
||||||
|
|
||||||
# current component spec
|
# current component spec
|
||||||
component_spec = self._component_specs.get(name)
|
component_spec = self._component_specs.get(name)
|
||||||
if component_spec is None:
|
if component_spec is None:
|
||||||
@@ -1107,7 +1121,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
is_registered = hasattr(self, name)
|
is_registered = hasattr(self, name)
|
||||||
|
|
||||||
if module is not None and not hasattr(module, "_diffusers_load_id"):
|
if module is not None and not hasattr(module, "_diffusers_load_id"):
|
||||||
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
|
raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
|
||||||
|
|
||||||
# actual library and class name of the module
|
# actual library and class name of the module
|
||||||
|
|
||||||
@@ -1143,12 +1157,20 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
current_module = getattr(self, name, None)
|
current_module = getattr(self, name, None)
|
||||||
# skip if the component is already registered with the same object
|
# skip if the component is already registered with the same object
|
||||||
if current_module is module:
|
if current_module is module:
|
||||||
logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping")
|
logger.info(
|
||||||
|
f"ModularLoader.register_components: {name} is already registered with same object, skipping"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# it module is not an instance of the expected type, still register it but with a warning
|
# it module is not an instance of the expected type, still register it but with a warning
|
||||||
if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint):
|
if (
|
||||||
logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
|
module is not None
|
||||||
|
and component_spec.type_hint is not None
|
||||||
|
and not isinstance(module, component_spec.type_hint)
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
# warn if unregister
|
# warn if unregister
|
||||||
if current_module is not None and module is None:
|
if current_module is not None and module is None:
|
||||||
@@ -1157,10 +1179,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
f"(was {current_module.__class__.__name__})"
|
f"(was {current_module.__class__.__name__})"
|
||||||
)
|
)
|
||||||
# same type, new instance → debug
|
# same type, new instance → debug
|
||||||
elif current_module is not None \
|
elif (
|
||||||
and module is not None \
|
current_module is not None
|
||||||
and isinstance(module, current_module.__class__) \
|
and module is not None
|
||||||
and current_module != module:
|
and isinstance(module, current_module.__class__)
|
||||||
|
and current_module != module
|
||||||
|
):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"ModularLoader.register_components: replacing existing '{name}' "
|
f"ModularLoader.register_components: replacing existing '{name}' "
|
||||||
f"(same type {type(current_module).__name__}, new instance)"
|
f"(same type {type(current_module).__name__}, new instance)"
|
||||||
@@ -1175,28 +1199,34 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
if module is not None and self._component_manager is not None:
|
if module is not None and self._component_manager is not None:
|
||||||
self._component_manager.add(name, module, self._collection)
|
self._component_manager.add(name, module, self._collection)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
|
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
|
||||||
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
specs: List[Union[ComponentSpec, ConfigSpec]],
|
||||||
|
modular_repo: Optional[str] = None,
|
||||||
|
component_manager: Optional[ComponentsManager] = None,
|
||||||
|
collection: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the loader with a list of component specs and config specs.
|
Initialize the loader with a list of component specs and config specs.
|
||||||
"""
|
"""
|
||||||
self._component_manager = component_manager
|
self._component_manager = component_manager
|
||||||
self._collection = collection
|
self._collection = collection
|
||||||
self._component_specs = {
|
self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)}
|
||||||
spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)
|
self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)}
|
||||||
}
|
|
||||||
self._config_specs = {
|
|
||||||
spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)
|
|
||||||
}
|
|
||||||
|
|
||||||
# update component_specs and config_specs from modular_repo
|
# update component_specs and config_specs from modular_repo
|
||||||
if modular_repo is not None:
|
if modular_repo is not None:
|
||||||
config_dict = self.load_config(modular_repo, **kwargs)
|
config_dict = self.load_config(modular_repo, **kwargs)
|
||||||
|
|
||||||
for name, value in config_dict.items():
|
for name, value in config_dict.items():
|
||||||
if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3:
|
if (
|
||||||
|
name in self._component_specs
|
||||||
|
and self._component_specs[name].default_creation_method == "from_pretrained"
|
||||||
|
and isinstance(value, (tuple, list))
|
||||||
|
and len(value) == 3
|
||||||
|
):
|
||||||
library, class_name, component_spec_dict = value
|
library, class_name, component_spec_dict = value
|
||||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||||
self._component_specs[name] = component_spec
|
self._component_specs[name] = component_spec
|
||||||
@@ -1214,7 +1244,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
default_configs[name] = config_spec.default
|
default_configs[name] = config_spec.default
|
||||||
self.register_to_config(**default_configs)
|
self.register_to_config(**default_configs)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
r"""
|
r"""
|
||||||
@@ -1280,15 +1309,10 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def components(self) -> Dict[str, Any]:
|
def components(self) -> Dict[str, Any]:
|
||||||
# return only components we've actually set as attributes on self
|
# return only components we've actually set as attributes on self
|
||||||
return {
|
return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
|
||||||
name: getattr(self, name)
|
|
||||||
for name in self._component_specs.keys()
|
|
||||||
if hasattr(self, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def update(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -1340,24 +1364,20 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
for name, component in passed_components.items():
|
for name, component in passed_components.items():
|
||||||
if not hasattr(component, "_diffusers_load_id"):
|
if not hasattr(component, "_diffusers_load_id"):
|
||||||
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
|
raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
|
||||||
|
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||||
|
|
||||||
|
|
||||||
self.register_components(**passed_components)
|
self.register_components(**passed_components)
|
||||||
|
|
||||||
|
|
||||||
config_to_register = {}
|
config_to_register = {}
|
||||||
for name, new_value in passed_config_values.items():
|
for name, new_value in passed_config_values.items():
|
||||||
|
|
||||||
# e.g. requires_aesthetics_score = False
|
# e.g. requires_aesthetics_score = False
|
||||||
self._config_specs[name].default = new_value
|
self._config_specs[name].default = new_value
|
||||||
config_to_register[name] = new_value
|
config_to_register[name] = new_value
|
||||||
self.register_to_config(**config_to_register)
|
self.register_to_config(**config_to_register)
|
||||||
|
|
||||||
|
|
||||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||||
def load(self, component_names: Optional[List[str]] = None, **kwargs):
|
def load(self, component_names: Optional[List[str]] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -1410,8 +1430,9 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
# YiYi TODO:
|
# YiYi TODO:
|
||||||
# 1. should support save some components too! currently only modular_model_index.json is saved
|
# 1. should support save some components too! currently only modular_model_index.json is saved
|
||||||
# 2. maybe order the json file to make it more readable: configs first, then components
|
# 2. maybe order the json file to make it more readable: configs first, then components
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs):
|
def save_pretrained(
|
||||||
|
self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs
|
||||||
|
):
|
||||||
component_names = list(self._component_specs.keys())
|
component_names = list(self._component_specs.keys())
|
||||||
config_names = list(self._config_specs.keys())
|
config_names = list(self._config_specs.keys())
|
||||||
self.register_to_config(_components_names=component_names, _configs_names=config_names)
|
self.register_to_config(_components_names=component_names, _configs_names=config_names)
|
||||||
@@ -1421,11 +1442,11 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
config.pop("_configs_names", None)
|
config.pop("_configs_names", None)
|
||||||
self._internal_dict = FrozenDict(config)
|
self._internal_dict = FrozenDict(config)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@validate_hf_hub_args
|
@validate_hf_hub_args
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs):
|
def from_pretrained(
|
||||||
|
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs
|
||||||
|
):
|
||||||
config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
|
config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
|
||||||
expected_component = set(config_dict.pop("_components_names"))
|
expected_component = set(config_dict.pop("_components_names"))
|
||||||
expected_config = set(config_dict.pop("_configs_names"))
|
expected_config = set(config_dict.pop("_configs_names"))
|
||||||
@@ -1450,7 +1471,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
|||||||
component_specs.append(ComponentSpec(name=name, default_creation_method="from_config"))
|
component_specs.append(ComponentSpec(name=name, default_creation_method="from_config"))
|
||||||
return cls(component_specs + config_specs)
|
return cls(component_specs + config_specs)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
|
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -12,13 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import re
|
|
||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass, asdict, field, fields
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
|
from dataclasses import dataclass, field, fields
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||||
from ..utils.import_utils import is_torch_available
|
from ..utils.import_utils import is_torch_available
|
||||||
from ..configuration_utils import FrozenDict, ConfigMixin
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -176,7 +177,7 @@ class ComponentSpec:
|
|||||||
|
|
||||||
if self.type_hint is None or not isinstance(self.type_hint, type):
|
if self.type_hint is None or not isinstance(self.type_hint, type):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`type_hint` is required when using from_config creation method."
|
"`type_hint` is required when using from_config creation method."
|
||||||
)
|
)
|
||||||
|
|
||||||
config = config or self.config or {}
|
config = config or self.config or {}
|
||||||
@@ -209,7 +210,7 @@ class ComponentSpec:
|
|||||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
||||||
repo = load_kwargs.pop("repo", None)
|
repo = load_kwargs.pop("repo", None)
|
||||||
if repo is None:
|
if repo is None:
|
||||||
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
raise ValueError("`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||||
|
|
||||||
if self.type_hint is None:
|
if self.type_hint is None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -334,6 +334,7 @@ def maybe_raise_or_warn(
|
|||||||
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
|
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
|
||||||
def simple_get_class_obj(library_name, class_name):
|
def simple_get_class_obj(library_name, class_name):
|
||||||
from diffusers import pipelines
|
from diffusers import pipelines
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
|
|
||||||
if is_pipeline_module:
|
if is_pipeline_module:
|
||||||
@@ -345,6 +346,7 @@ def simple_get_class_obj(library_name, class_name):
|
|||||||
|
|
||||||
return class_obj
|
return class_obj
|
||||||
|
|
||||||
|
|
||||||
def get_class_obj_and_candidates(
|
def get_class_obj_and_candidates(
|
||||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
||||||
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
||||||
from .pipeline_stable_diffusion_xl_modular import (
|
from .pipeline_stable_diffusion_xl_modular import (
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
StableDiffusionXLControlNetDenoiseStep,
|
StableDiffusionXLControlNetDenoiseStep,
|
||||||
StableDiffusionXLDecodeLatentsStep,
|
StableDiffusionXLDecodeLatentsStep,
|
||||||
StableDiffusionXLDenoiseStep,
|
StableDiffusionXLDenoiseStep,
|
||||||
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLPrepareLatentsStep,
|
StableDiffusionXLPrepareLatentsStep,
|
||||||
StableDiffusionXLSetTimestepsStep,
|
StableDiffusionXLSetTimestepsStep,
|
||||||
StableDiffusionXLTextEncoderStep,
|
StableDiffusionXLTextEncoderStep,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
+22
-24
@@ -13,17 +13,28 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, List, Optional, Tuple, Union, Dict
|
from collections import OrderedDict
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from collections import OrderedDict
|
from transformers import (
|
||||||
|
CLIPImageProcessor,
|
||||||
|
CLIPTextModel,
|
||||||
|
CLIPTextModelWithProjection,
|
||||||
|
CLIPTokenizer,
|
||||||
|
CLIPVisionModelWithProjection,
|
||||||
|
)
|
||||||
|
|
||||||
from ...image_processor import VaeImageProcessor, PipelineImageInput
|
from ...configuration_utils import FrozenDict
|
||||||
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
|
from ...guiders import ClassifierFreeGuidance
|
||||||
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
|
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
|
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
|
||||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
from ...models.lora import adjust_lora_scale_text_encoder
|
from ...models.lora import adjust_lora_scale_text_encoder
|
||||||
|
from ...schedulers import EulerDiscreteScheduler
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
USE_PEFT_BACKEND,
|
USE_PEFT_BACKEND,
|
||||||
logging,
|
logging,
|
||||||
@@ -34,33 +45,20 @@ from ...utils.torch_utils import randn_tensor, unwrap_module
|
|||||||
from ..controlnet.multicontrolnet import MultiControlNetModel
|
from ..controlnet.multicontrolnet import MultiControlNetModel
|
||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import (
|
||||||
AutoPipelineBlocks,
|
AutoPipelineBlocks,
|
||||||
ModularLoader,
|
|
||||||
PipelineBlock,
|
|
||||||
PipelineState,
|
|
||||||
InputParam,
|
|
||||||
OutputParam,
|
|
||||||
SequentialPipelineBlocks,
|
|
||||||
ComponentSpec,
|
ComponentSpec,
|
||||||
ConfigSpec,
|
ConfigSpec,
|
||||||
|
InputParam,
|
||||||
|
ModularLoader,
|
||||||
|
OutputParam,
|
||||||
|
PipelineBlock,
|
||||||
|
PipelineState,
|
||||||
|
SequentialPipelineBlocks,
|
||||||
)
|
)
|
||||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||||
from .pipeline_output import (
|
from .pipeline_output import (
|
||||||
StableDiffusionXLPipelineOutput,
|
StableDiffusionXLPipelineOutput,
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
CLIPTextModel,
|
|
||||||
CLIPImageProcessor,
|
|
||||||
CLIPTextModelWithProjection,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPVisionModelWithProjection,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ...schedulers import EulerDiscreteScheduler
|
|
||||||
from ...guiders import ClassifierFreeGuidance
|
|
||||||
from ...configuration_utils import FrozenDict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utilities to dynamically load objects from the Hub."""
|
"""Utilities to dynamically load objects from the Hub."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@@ -21,8 +22,9 @@ import os
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, ModuleType, Optional, Union
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, model_info
|
from huggingface_hub import hf_hub_download, model_info
|
||||||
@@ -37,6 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
||||||
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
||||||
|
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_diffusers_versions():
|
def get_diffusers_versions():
|
||||||
@@ -154,15 +157,132 @@ def check_imports(filename):
|
|||||||
return get_relative_imports(filename)
|
return get_relative_imports(filename)
|
||||||
|
|
||||||
|
|
||||||
def get_class_in_module(class_name, module_path):
|
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
|
||||||
|
if trust_remote_code is None:
|
||||||
|
if has_local_code:
|
||||||
|
trust_remote_code = False
|
||||||
|
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||||
|
prev_sig_handler = None
|
||||||
|
try:
|
||||||
|
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||||
|
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||||
|
while trust_remote_code is None:
|
||||||
|
answer = input(
|
||||||
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||||
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||||
|
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||||
|
f"Do you wish to run the custom code? [y/N] "
|
||||||
|
)
|
||||||
|
if answer.lower() in ["yes", "y", "1"]:
|
||||||
|
trust_remote_code = True
|
||||||
|
elif answer.lower() in ["no", "n", "0", ""]:
|
||||||
|
trust_remote_code = False
|
||||||
|
signal.alarm(0)
|
||||||
|
except Exception:
|
||||||
|
# OS which does not support signal.SIGALRM
|
||||||
|
raise ValueError(
|
||||||
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||||
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||||
|
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if prev_sig_handler is not None:
|
||||||
|
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||||
|
signal.alarm(0)
|
||||||
|
elif has_remote_code:
|
||||||
|
# For the CI which puts the timeout at 0
|
||||||
|
_raise_timeout_error(None, None)
|
||||||
|
|
||||||
|
if has_remote_code and not has_local_code and not trust_remote_code:
|
||||||
|
raise ValueError(
|
||||||
|
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||||
|
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||||
|
" set the option `trust_remote_code=True` to remove this error."
|
||||||
|
)
|
||||||
|
|
||||||
|
return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_in_modular_module(
|
||||||
|
class_name: str,
|
||||||
|
module_path: Union[str, os.PathLike],
|
||||||
|
*,
|
||||||
|
force_reload: bool = False,
|
||||||
|
) -> type:
|
||||||
|
"""
|
||||||
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_name (`str`): The name of the class to import.
|
||||||
|
module_path (`str` or `os.PathLike`): The path to the module to import.
|
||||||
|
force_reload (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
|
||||||
|
Otherwise, the module is only reloaded if the file has changed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`typing.Type`: The class looked for.
|
||||||
|
"""
|
||||||
|
name = os.path.normpath(module_path)
|
||||||
|
if name.endswith(".py"):
|
||||||
|
name = name[:-3]
|
||||||
|
name = name.replace(os.path.sep, ".")
|
||||||
|
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||||
|
with _HF_REMOTE_CODE_LOCK:
|
||||||
|
if force_reload:
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||||
|
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||||
|
|
||||||
|
# Hash the module file and all its relative imports to check if we need to reload it
|
||||||
|
module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
|
||||||
|
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
|
||||||
|
|
||||||
|
module: ModuleType
|
||||||
|
if cached_module is None:
|
||||||
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
|
# insert it into sys.modules before any loading begins
|
||||||
|
sys.modules[name] = module
|
||||||
|
else:
|
||||||
|
module = cached_module
|
||||||
|
# reload in both cases, unless the module is already imported and the hash hits
|
||||||
|
if getattr(module, "__transformers_module_hash__", "") != module_hash:
|
||||||
|
module_spec.loader.exec_module(module)
|
||||||
|
module.__transformers_module_hash__ = module_hash
|
||||||
|
|
||||||
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_in_module(class_name, module_path, force_reload=False):
|
||||||
"""
|
"""
|
||||||
Import a module on the cache directory for modules and extract a class from it.
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
"""
|
"""
|
||||||
module_path = module_path.replace(os.path.sep, ".")
|
name = os.path.normpath(module_path)
|
||||||
module = importlib.import_module(module_path)
|
if name.endswith(".py"):
|
||||||
|
name = name[:-3]
|
||||||
|
name = name.replace(os.path.sep, ".")
|
||||||
|
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||||
|
|
||||||
|
with _HF_REMOTE_CODE_LOCK:
|
||||||
|
if force_reload:
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||||
|
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||||
|
|
||||||
|
module: ModuleType
|
||||||
|
if cached_module is None:
|
||||||
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
|
# insert it into sys.modules before any loading begins
|
||||||
|
sys.modules[name] = module
|
||||||
|
else:
|
||||||
|
module = cached_module
|
||||||
|
|
||||||
|
module_spec.loader.exec_module(module)
|
||||||
|
|
||||||
if class_name is None:
|
if class_name is None:
|
||||||
return find_pipeline_class(module)
|
return find_pipeline_class(module)
|
||||||
|
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
@@ -203,6 +323,7 @@ def get_cached_module_file(
|
|||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
is_modular: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
||||||
@@ -257,7 +378,7 @@ def get_cached_module_file(
|
|||||||
if os.path.isfile(module_file_or_url):
|
if os.path.isfile(module_file_or_url):
|
||||||
resolved_module_file = module_file_or_url
|
resolved_module_file = module_file_or_url
|
||||||
submodule = "local"
|
submodule = "local"
|
||||||
elif pretrained_model_name_or_path.count("/") == 0:
|
elif pretrained_model_name_or_path.count("/") == 0 and not is_modular:
|
||||||
available_versions = get_diffusers_versions()
|
available_versions = get_diffusers_versions()
|
||||||
# cut ".dev0"
|
# cut ".dev0"
|
||||||
latest_version = "v" + ".".join(__version__.split(".")[:3])
|
latest_version = "v" + ".".join(__version__.split(".")[:3])
|
||||||
@@ -297,6 +418,24 @@ def get_cached_module_file(
|
|||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
elif is_modular:
|
||||||
|
try:
|
||||||
|
# Load from URL or cache if already cached
|
||||||
|
resolved_module_file = hf_hub_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
module_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||||
|
except EnvironmentError:
|
||||||
|
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||||
|
raise
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
@@ -381,6 +520,7 @@ def get_class_from_dynamic_module(
|
|||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
is_modular: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -453,5 +593,7 @@ def get_class_from_dynamic_module(
|
|||||||
token=token,
|
token=token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
is_modular=is_modular,
|
||||||
)
|
)
|
||||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
__import__("ipdb").set_trace()
|
||||||
|
return get_class_in_module(class_name, final_module)
|
||||||
|
|||||||
Reference in New Issue
Block a user