Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d34c4e8caf | |||
| b46b7c8b31 | |||
| fc9168f429 | |||
| 31a31ca1c5 | |||
| 8423652b35 | |||
| de631947cc | |||
| 58e9565719 | |||
| cb6d5fed19 | |||
| f16e9c7807 | |||
| 87f63d424a | |||
| 29de29f02c | |||
| 72e1b74638 | |||
| 3471f2fb75 | |||
| d136ae36c8 | |||
| 1b89ac144c | |||
| eb9415031a | |||
| de6ab6b49d | |||
| 4968edc5dc | |||
| 808dff09cb | |||
| 61dac3bbe4 | |||
| 73ab5725c2 | |||
| 163341d3dd | |||
| d0fbf745e6 | |||
| 27c1158b23 | |||
| 96ce6744fe | |||
| 8ad14a52cb | |||
| a7fb2d2a22 | |||
| a0deefb606 | |||
| e2491af650 | |||
| 506a8ea09c | |||
| 58358c2d00 | |||
| 5cde77f915 | |||
| 522e827625 | |||
| 144eae4e0b | |||
| 796453cad1 | |||
| 153ae34ff6 | |||
| 0acb5e1460 | |||
| 462429b687 | |||
| cf01aaeb49 | |||
| 2017ae5624 | |||
| 2b361a2413 | |||
| c677d528e4 | |||
| 0f0618ff2b | |||
| d89631fc50 | |||
| 16b6583fa8 | |||
| f552773572 | |||
| dc4dbfe107 | |||
| 43ac1ff7e7 | |||
| efd70b7838 | |||
| 7ca860c24b | |||
| 7b86fcea31 | |||
| c8b5d56412 |
@@ -39,6 +39,7 @@ _import_structure = {
|
|||||||
"loaders": ["FromOriginalModelMixin"],
|
"loaders": ["FromOriginalModelMixin"],
|
||||||
"models": [],
|
"models": [],
|
||||||
"pipelines": [],
|
"pipelines": [],
|
||||||
|
"modular_pipelines": [],
|
||||||
"quantizers.quantization_config": [],
|
"quantizers.quantization_config": [],
|
||||||
"schedulers": [],
|
"schedulers": [],
|
||||||
"utils": [
|
"utils": [
|
||||||
@@ -254,13 +255,21 @@ else:
|
|||||||
"KarrasVePipeline",
|
"KarrasVePipeline",
|
||||||
"LDMPipeline",
|
"LDMPipeline",
|
||||||
"LDMSuperResolutionPipeline",
|
"LDMSuperResolutionPipeline",
|
||||||
"ModularLoader",
|
|
||||||
"PNDMPipeline",
|
"PNDMPipeline",
|
||||||
"RePaintPipeline",
|
"RePaintPipeline",
|
||||||
"ScoreSdeVePipeline",
|
"ScoreSdeVePipeline",
|
||||||
"StableDiffusionMixin",
|
"StableDiffusionMixin",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["modular_pipelines"].extend(
|
||||||
|
[
|
||||||
|
"ModularLoader",
|
||||||
|
"ModularPipeline",
|
||||||
|
"ModularPipelineBlocks",
|
||||||
|
"ComponentSpec",
|
||||||
|
"ComponentsManager",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["quantizers"] = ["DiffusersQuantizer"]
|
_import_structure["quantizers"] = ["DiffusersQuantizer"]
|
||||||
_import_structure["schedulers"].extend(
|
_import_structure["schedulers"].extend(
|
||||||
[
|
[
|
||||||
@@ -509,12 +518,10 @@ else:
|
|||||||
"StableDiffusionXLImg2ImgPipeline",
|
"StableDiffusionXLImg2ImgPipeline",
|
||||||
"StableDiffusionXLInpaintPipeline",
|
"StableDiffusionXLInpaintPipeline",
|
||||||
"StableDiffusionXLInstructPix2PixPipeline",
|
"StableDiffusionXLInstructPix2PixPipeline",
|
||||||
"StableDiffusionXLModularLoader",
|
|
||||||
"StableDiffusionXLPAGImg2ImgPipeline",
|
"StableDiffusionXLPAGImg2ImgPipeline",
|
||||||
"StableDiffusionXLPAGInpaintPipeline",
|
"StableDiffusionXLPAGInpaintPipeline",
|
||||||
"StableDiffusionXLPAGPipeline",
|
"StableDiffusionXLPAGPipeline",
|
||||||
"StableDiffusionXLPipeline",
|
"StableDiffusionXLPipeline",
|
||||||
"StableDiffusionXLAutoPipeline",
|
|
||||||
"StableUnCLIPImg2ImgPipeline",
|
"StableUnCLIPImg2ImgPipeline",
|
||||||
"StableUnCLIPPipeline",
|
"StableUnCLIPPipeline",
|
||||||
"StableVideoDiffusionPipeline",
|
"StableVideoDiffusionPipeline",
|
||||||
@@ -541,6 +548,24 @@ else:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (is_torch_available() and is_transformers_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from .utils import dummy_torch_and_transformers_objects # noqa F403
|
||||||
|
|
||||||
|
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
|
||||||
|
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
|
||||||
|
]
|
||||||
|
|
||||||
|
else:
|
||||||
|
_import_structure["modular_pipelines"].extend(
|
||||||
|
[
|
||||||
|
"StableDiffusionXLAutoPipeline",
|
||||||
|
"StableDiffusionXLModularLoader",
|
||||||
|
]
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
@@ -864,12 +889,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
KarrasVePipeline,
|
KarrasVePipeline,
|
||||||
LDMPipeline,
|
LDMPipeline,
|
||||||
LDMSuperResolutionPipeline,
|
LDMSuperResolutionPipeline,
|
||||||
ModularLoader,
|
|
||||||
PNDMPipeline,
|
PNDMPipeline,
|
||||||
RePaintPipeline,
|
RePaintPipeline,
|
||||||
ScoreSdeVePipeline,
|
ScoreSdeVePipeline,
|
||||||
StableDiffusionMixin,
|
StableDiffusionMixin,
|
||||||
)
|
)
|
||||||
|
from .modular_pipelines import (
|
||||||
|
ModularLoader,
|
||||||
|
ModularPipeline,
|
||||||
|
ModularPipelineBlocks,
|
||||||
|
ComponentSpec,
|
||||||
|
ComponentsManager,
|
||||||
|
)
|
||||||
from .quantizers import DiffusersQuantizer
|
from .quantizers import DiffusersQuantizer
|
||||||
from .schedulers import (
|
from .schedulers import (
|
||||||
AmusedScheduler,
|
AmusedScheduler,
|
||||||
@@ -1097,12 +1128,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLImg2ImgPipeline,
|
StableDiffusionXLImg2ImgPipeline,
|
||||||
StableDiffusionXLInpaintPipeline,
|
StableDiffusionXLInpaintPipeline,
|
||||||
StableDiffusionXLInstructPix2PixPipeline,
|
StableDiffusionXLInstructPix2PixPipeline,
|
||||||
StableDiffusionXLModularLoader,
|
|
||||||
StableDiffusionXLPAGImg2ImgPipeline,
|
StableDiffusionXLPAGImg2ImgPipeline,
|
||||||
StableDiffusionXLPAGInpaintPipeline,
|
StableDiffusionXLPAGInpaintPipeline,
|
||||||
StableDiffusionXLPAGPipeline,
|
StableDiffusionXLPAGPipeline,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
StableUnCLIPImg2ImgPipeline,
|
StableUnCLIPImg2ImgPipeline,
|
||||||
StableUnCLIPPipeline,
|
StableUnCLIPPipeline,
|
||||||
StableVideoDiffusionPipeline,
|
StableVideoDiffusionPipeline,
|
||||||
@@ -1127,7 +1156,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
WuerstchenDecoderPipeline,
|
WuerstchenDecoderPipeline,
|
||||||
WuerstchenPriorPipeline,
|
WuerstchenPriorPipeline,
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
if not (is_torch_available() and is_transformers_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
|
else:
|
||||||
|
from .modular_pipelines import (
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
|
StableDiffusionXLModularLoader,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
|
|||||||
@@ -13,14 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveProjectedGuidance(BaseGuidance):
|
class AdaptiveProjectedGuidance(BaseGuidance):
|
||||||
@@ -73,14 +73,18 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
|||||||
self.use_original_formulation = use_original_formulation
|
self.use_original_formulation = use_original_formulation
|
||||||
self.momentum_buffer = None
|
self.momentum_buffer = None
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
if self._step == 0:
|
if self._step == 0:
|
||||||
if self.adaptive_projected_guidance_momentum is not None:
|
if self.adaptive_projected_guidance_momentum is not None:
|
||||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ from ..hooks.layer_skip import _apply_layer_skip_hook
|
|||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class AutoGuidance(BaseGuidance):
|
class AutoGuidance(BaseGuidance):
|
||||||
@@ -120,11 +120,15 @@ class AutoGuidance(BaseGuidance):
|
|||||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||||
registry.remove_hook(name, recurse=True)
|
registry.remove_hook(name, recurse=True)
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -13,14 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class ClassifierFreeGuidance(BaseGuidance):
|
class ClassifierFreeGuidance(BaseGuidance):
|
||||||
@@ -75,11 +75,15 @@ class ClassifierFreeGuidance(BaseGuidance):
|
|||||||
self.guidance_rescale = guidance_rescale
|
self.guidance_rescale = guidance_rescale
|
||||||
self.use_original_formulation = use_original_formulation
|
self.use_original_formulation = use_original_formulation
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -13,14 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||||
@@ -73,11 +73,15 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
|||||||
self.guidance_rescale = guidance_rescale
|
self.guidance_rescale = guidance_rescale
|
||||||
self.use_original_formulation = use_original_formulation
|
self.use_original_formulation = use_original_formulation
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from ..utils import get_logger
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -171,10 +171,10 @@ class BaseGuidance:
|
|||||||
Returns:
|
Returns:
|
||||||
`BlockState`: The prepared batch of data.
|
`BlockState`: The prepared batch of data.
|
||||||
"""
|
"""
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
if input_fields is None:
|
if input_fields is None:
|
||||||
raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
|
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
|
||||||
data_batch = {}
|
data_batch = {}
|
||||||
for key, value in input_fields.items():
|
for key, value in input_fields.items():
|
||||||
try:
|
try:
|
||||||
@@ -186,7 +186,7 @@ class BaseGuidance:
|
|||||||
# We've already checked that value is a string or a tuple of strings with length 2
|
# We've already checked that value is a string or a tuple of strings with length 2
|
||||||
pass
|
pass
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
|
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||||
data_batch[cls._identifier_key] = identifier
|
data_batch[cls._identifier_key] = identifier
|
||||||
return BlockState(**data_batch)
|
return BlockState(**data_batch)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ from ..hooks.layer_skip import _apply_layer_skip_hook
|
|||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidance(BaseGuidance):
|
class SkipLayerGuidance(BaseGuidance):
|
||||||
@@ -156,7 +156,11 @@ class SkipLayerGuidance(BaseGuidance):
|
|||||||
for hook_name in self._skip_layer_hook_names:
|
for hook_name in self._skip_layer_hook_names:
|
||||||
registry.remove_hook(hook_name, recurse=True)
|
registry.remove_hook(hook_name, recurse=True)
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
if self.num_conditions == 1:
|
if self.num_conditions == 1:
|
||||||
tuple_indices = [0]
|
tuple_indices = [0]
|
||||||
input_predictions = ["pred_cond"]
|
input_predictions = ["pred_cond"]
|
||||||
@@ -168,7 +172,7 @@ class SkipLayerGuidance(BaseGuidance):
|
|||||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Union, TYPE_CHECKING
|
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig,
|
|||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class SmoothedEnergyGuidance(BaseGuidance):
|
class SmoothedEnergyGuidance(BaseGuidance):
|
||||||
@@ -149,7 +149,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
|||||||
for hook_name in self._seg_layer_hook_names:
|
for hook_name in self._seg_layer_hook_names:
|
||||||
registry.remove_hook(hook_name, recurse=True)
|
registry.remove_hook(hook_name, recurse=True)
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
if self.num_conditions == 1:
|
if self.num_conditions == 1:
|
||||||
tuple_indices = [0]
|
tuple_indices = [0]
|
||||||
input_predictions = ["pred_cond"]
|
input_predictions = ["pred_cond"]
|
||||||
@@ -161,7 +165,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
|||||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -13,14 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..pipelines.modular_pipeline import BlockState
|
from ..modular_pipelines.modular_pipeline import BlockState
|
||||||
|
|
||||||
|
|
||||||
class TangentialClassifierFreeGuidance(BaseGuidance):
|
class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||||
@@ -62,11 +62,15 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
|||||||
self.guidance_rescale = guidance_rescale
|
self.guidance_rescale = guidance_rescale
|
||||||
self.use_original_formulation = use_original_formulation
|
self.use_original_formulation = use_original_formulation
|
||||||
|
|
||||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
|
||||||
|
|
||||||
|
if input_fields is None:
|
||||||
|
input_fields = self._input_fields
|
||||||
|
|
||||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||||
data_batches = []
|
data_batches = []
|
||||||
for i in range(self.num_conditions):
|
for i in range(self.num_conditions):
|
||||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||||
data_batches.append(data_batch)
|
data_batches.append(data_batch)
|
||||||
return data_batches
|
return data_batches
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
||||||
|
|
||||||
|
|
||||||
|
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
|
||||||
|
# either remove or make it serializable
|
||||||
@dataclass
|
@dataclass
|
||||||
class LayerSkipConfig:
|
class LayerSkipConfig:
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
DIFFUSERS_SLOW_IMPORT,
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
get_objects_from_module,
|
||||||
|
is_torch_available,
|
||||||
|
is_transformers_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# These modules contain pipelines from multiple libraries/frameworks
|
||||||
|
_dummy_objects = {}
|
||||||
|
_import_structure = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ..utils import dummy_pt_objects # noqa F403
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
||||||
|
else:
|
||||||
|
_import_structure["modular_pipeline"] = [
|
||||||
|
"ModularPipelineBlocks",
|
||||||
|
"ModularPipeline",
|
||||||
|
"PipelineBlock",
|
||||||
|
"AutoPipelineBlocks",
|
||||||
|
"SequentialPipelineBlocks",
|
||||||
|
"LoopSequentialPipelineBlocks",
|
||||||
|
"ModularLoader",
|
||||||
|
"PipelineState",
|
||||||
|
"BlockState",
|
||||||
|
]
|
||||||
|
_import_structure["modular_pipeline_utils"] = [
|
||||||
|
"ComponentSpec",
|
||||||
|
"ConfigSpec",
|
||||||
|
"InputParam",
|
||||||
|
"OutputParam",
|
||||||
|
]
|
||||||
|
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"]
|
||||||
|
_import_structure["components_manager"] = ["ComponentsManager"]
|
||||||
|
|
||||||
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
|
try:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ..utils.dummy_pt_objects import * # noqa F403
|
||||||
|
else:
|
||||||
|
from .modular_pipeline import (
|
||||||
|
AutoPipelineBlocks,
|
||||||
|
BlockState,
|
||||||
|
LoopSequentialPipelineBlocks,
|
||||||
|
ModularLoader,
|
||||||
|
ModularPipelineBlocks,
|
||||||
|
ModularPipeline,
|
||||||
|
PipelineBlock,
|
||||||
|
PipelineState,
|
||||||
|
SequentialPipelineBlocks,
|
||||||
|
)
|
||||||
|
from .modular_pipeline_utils import (
|
||||||
|
ComponentSpec,
|
||||||
|
ConfigSpec,
|
||||||
|
InputParam,
|
||||||
|
OutputParam,
|
||||||
|
)
|
||||||
|
from .stable_diffusion_xl import (
|
||||||
|
StableDiffusionXLAutoPipeline,
|
||||||
|
StableDiffusionXLModularLoader,
|
||||||
|
)
|
||||||
|
from .components_manager import ComponentsManager
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__,
|
||||||
|
globals()["__file__"],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
)
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
+142
-70
@@ -29,6 +29,9 @@ from ..models.modeling_utils import ModelMixin
|
|||||||
from .modular_pipeline_utils import ComponentSpec
|
from .modular_pipeline_utils import ComponentSpec
|
||||||
|
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
|
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
|
||||||
from accelerate.state import PartialState
|
from accelerate.state import PartialState
|
||||||
@@ -231,8 +234,6 @@ class AutoOffloadStrategy:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
from .modular_pipeline_utils import ComponentSpec
|
|
||||||
import uuid
|
|
||||||
class ComponentsManager:
|
class ComponentsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.components = OrderedDict()
|
self.components = OrderedDict()
|
||||||
@@ -242,78 +243,122 @@ class ComponentsManager:
|
|||||||
self._auto_offload_enabled = False
|
self._auto_offload_enabled = False
|
||||||
|
|
||||||
|
|
||||||
def _get_by_collection(self, collection: str):
|
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
|
||||||
"""
|
"""
|
||||||
Select components by collection name.
|
Lookup component_ids by name, collection, or load_id.
|
||||||
"""
|
"""
|
||||||
selected_components = {}
|
if components is None:
|
||||||
if collection in self.collections:
|
components = self.components
|
||||||
component_ids = self.collections[collection]
|
|
||||||
for component_id in component_ids:
|
|
||||||
selected_components[component_id] = self.components[component_id]
|
|
||||||
return selected_components
|
|
||||||
|
|
||||||
|
if name:
|
||||||
|
ids_by_name = set()
|
||||||
|
for component_id, component in components.items():
|
||||||
|
comp_name = self._id_to_name(component_id)
|
||||||
|
if comp_name == name:
|
||||||
|
ids_by_name.add(component_id)
|
||||||
|
else:
|
||||||
|
ids_by_name = set(components.keys())
|
||||||
|
if collection:
|
||||||
|
ids_by_collection = set()
|
||||||
|
for component_id, component in components.items():
|
||||||
|
if component_id in self.collections[collection]:
|
||||||
|
ids_by_collection.add(component_id)
|
||||||
|
else:
|
||||||
|
ids_by_collection = set(components.keys())
|
||||||
|
if load_id:
|
||||||
|
ids_by_load_id = set()
|
||||||
|
for name, component in components.items():
|
||||||
|
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
|
||||||
|
ids_by_load_id.add(name)
|
||||||
|
else:
|
||||||
|
ids_by_load_id = set(components.keys())
|
||||||
|
|
||||||
def _get_by_load_id(self, load_id: str):
|
ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
|
||||||
"""
|
return ids
|
||||||
Select components by its load_id.
|
|
||||||
"""
|
|
||||||
selected_components = {}
|
|
||||||
for name, component in self.components.items():
|
|
||||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
|
|
||||||
selected_components[name] = component
|
|
||||||
return selected_components
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _id_to_name(component_id: str):
|
||||||
|
return "_".join(component_id.split("_")[:-1])
|
||||||
|
|
||||||
def add(self, name, component, collection: Optional[str] = None):
|
def add(self, name, component, collection: Optional[str] = None):
|
||||||
|
|
||||||
for comp_id, comp in self.components.items():
|
|
||||||
if comp == component:
|
|
||||||
logger.warning(f"Component '{name}' already exists in ComponentsManager")
|
|
||||||
return comp_id
|
|
||||||
|
|
||||||
component_id = f"{name}_{uuid.uuid4()}"
|
component_id = f"{name}_{uuid.uuid4()}"
|
||||||
|
|
||||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
|
# check for duplicated components
|
||||||
components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id)
|
for comp_id, comp in self.components.items():
|
||||||
if components_with_same_load_id:
|
if comp == component:
|
||||||
existing = ", ".join(components_with_same_load_id.keys())
|
comp_name = self._id_to_name(comp_id)
|
||||||
logger.warning(
|
if comp_name == name:
|
||||||
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
logger.warning(
|
||||||
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
|
f"component '{name}' already exists as '{comp_id}'"
|
||||||
)
|
)
|
||||||
|
component_id = comp_id
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
|
||||||
|
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# check for duplicated load_id and warn (we do not delete for you)
|
||||||
|
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
|
||||||
|
components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
|
||||||
|
components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
|
||||||
|
|
||||||
|
if components_with_same_load_id:
|
||||||
|
existing = ", ".join(components_with_same_load_id)
|
||||||
|
logger.warning(
|
||||||
|
f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
||||||
|
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
|
||||||
|
)
|
||||||
|
|
||||||
# add component to components manager
|
# add component to components manager
|
||||||
self.components[component_id] = component
|
self.components[component_id] = component
|
||||||
self.added_time[component_id] = time.time()
|
self.added_time[component_id] = time.time()
|
||||||
|
|
||||||
if collection:
|
if collection:
|
||||||
if collection not in self.collections:
|
if collection not in self.collections:
|
||||||
self.collections[collection] = set()
|
self.collections[collection] = set()
|
||||||
self.collections[collection].add(component_id)
|
if not component_id in self.collections[collection]:
|
||||||
|
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
|
||||||
|
for comp_id in comp_ids_in_collection:
|
||||||
|
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
|
||||||
|
self.remove(comp_id)
|
||||||
|
self.collections[collection].add(component_id)
|
||||||
|
logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Added component '{name}' as '{component_id}'")
|
||||||
|
|
||||||
if self._auto_offload_enabled:
|
if self._auto_offload_enabled:
|
||||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||||
|
|
||||||
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
|
|
||||||
return component_id
|
return component_id
|
||||||
|
|
||||||
|
|
||||||
def remove(self, name: Union[str, List[str]]):
|
def remove(self, component_id: str = None):
|
||||||
|
|
||||||
if name not in self.components:
|
if component_id not in self.components:
|
||||||
logger.warning(f"Component '{name}' not found in ComponentsManager")
|
logger.warning(f"Component '{component_id}' not found in ComponentsManager")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.components.pop(name)
|
component = self.components.pop(component_id)
|
||||||
self.added_time.pop(name)
|
self.added_time.pop(component_id)
|
||||||
|
|
||||||
for collection in self.collections:
|
for collection in self.collections:
|
||||||
if name in self.collections[collection]:
|
if component_id in self.collections[collection]:
|
||||||
self.collections[collection].remove(name)
|
self.collections[collection].remove(component_id)
|
||||||
|
|
||||||
if self._auto_offload_enabled:
|
if self._auto_offload_enabled:
|
||||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||||
|
else:
|
||||||
|
if isinstance(component, torch.nn.Module):
|
||||||
|
component.to("cpu")
|
||||||
|
del component
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
|
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
|
||||||
as_name_component_tuples: bool = False):
|
as_name_component_tuples: bool = False):
|
||||||
@@ -342,16 +387,8 @@ class ComponentsManager:
|
|||||||
or list of (base_name, component) tuples if as_name_component_tuples=True
|
or list of (base_name, component) tuples if as_name_component_tuples=True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if collection:
|
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
|
||||||
if collection not in self.collections:
|
components = {k: self.components[k] for k in selected_ids}
|
||||||
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
|
||||||
return [] if as_name_component_tuples else {}
|
|
||||||
components = self._get_by_collection(collection)
|
|
||||||
else:
|
|
||||||
components = self.components
|
|
||||||
|
|
||||||
if load_id:
|
|
||||||
components = self._get_by_load_id(load_id)
|
|
||||||
|
|
||||||
# Helper to extract base name from component_id
|
# Helper to extract base name from component_id
|
||||||
def get_base_name(component_id):
|
def get_base_name(component_id):
|
||||||
@@ -541,11 +578,11 @@ class ComponentsManager:
|
|||||||
self._auto_offload_enabled = False
|
self._auto_offload_enabled = False
|
||||||
|
|
||||||
# YiYi TODO: add quantization info
|
# YiYi TODO: add quantization info
|
||||||
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
|
def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
|
||||||
"""Get comprehensive information about a component.
|
"""Get comprehensive information about a component.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of the component to get info for
|
component_id: Name of the component to get info for
|
||||||
fields: Optional field(s) to return. Can be a string for single field or list of fields.
|
fields: Optional field(s) to return. Can be a string for single field or list of fields.
|
||||||
If None, returns all fields.
|
If None, returns all fields.
|
||||||
|
|
||||||
@@ -554,16 +591,16 @@ class ComponentsManager:
|
|||||||
If fields is specified, returns only those fields.
|
If fields is specified, returns only those fields.
|
||||||
If a single field is requested as string, returns just that field's value.
|
If a single field is requested as string, returns just that field's value.
|
||||||
"""
|
"""
|
||||||
if name not in self.components:
|
if component_id not in self.components:
|
||||||
raise ValueError(f"Component '{name}' not found in ComponentsManager")
|
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
|
||||||
|
|
||||||
component = self.components[name]
|
component = self.components[component_id]
|
||||||
|
|
||||||
# Build complete info dict first
|
# Build complete info dict first
|
||||||
info = {
|
info = {
|
||||||
"model_id": name,
|
"model_id": component_id,
|
||||||
"added_time": self.added_time[name],
|
"added_time": self.added_time[component_id],
|
||||||
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
|
"collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Additional info for torch.nn.Module components
|
# Additional info for torch.nn.Module components
|
||||||
@@ -649,11 +686,19 @@ class ComponentsManager:
|
|||||||
]
|
]
|
||||||
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
|
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
|
||||||
|
|
||||||
# Collection names
|
# Get all collections for each component
|
||||||
collection_names = [
|
component_collections = {}
|
||||||
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
|
for name in self.components.keys():
|
||||||
for name in self.components.keys()
|
component_collections[name] = []
|
||||||
]
|
for coll, comps in self.collections.items():
|
||||||
|
if name in comps:
|
||||||
|
component_collections[name].append(coll)
|
||||||
|
if not component_collections[name]:
|
||||||
|
component_collections[name] = ["N/A"]
|
||||||
|
|
||||||
|
# Find the maximum collection name length
|
||||||
|
all_collections = [coll for colls in component_collections.values() for coll in colls]
|
||||||
|
max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
|
||||||
|
|
||||||
col_widths = {
|
col_widths = {
|
||||||
"name": max(15, max(len(name) for name in simple_names)),
|
"name": max(15, max(len(name) for name in simple_names)),
|
||||||
@@ -662,7 +707,7 @@ class ComponentsManager:
|
|||||||
"dtype": 15,
|
"dtype": 15,
|
||||||
"size": 10,
|
"size": 10,
|
||||||
"load_id": max_load_id_len,
|
"load_id": max_load_id_len,
|
||||||
"collection": max(10, max(len(str(c)) for c in collection_names))
|
"collection": max_collection_len
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create the header lines
|
# Create the header lines
|
||||||
@@ -691,11 +736,21 @@ class ComponentsManager:
|
|||||||
device_str = format_device(component, info)
|
device_str = format_device(component, info)
|
||||||
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
||||||
load_id = get_load_id(component)
|
load_id = get_load_id(component)
|
||||||
collection = info["collection"] or "N/A"
|
|
||||||
|
# Print first collection on the main line
|
||||||
|
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
|
||||||
|
|
||||||
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
|
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
|
||||||
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
|
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
|
||||||
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
|
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
|
||||||
|
|
||||||
|
# Print additional collections on separate lines if they exist
|
||||||
|
for i in range(1, len(component_collections[name])):
|
||||||
|
collection = component_collections[name][i]
|
||||||
|
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | "
|
||||||
|
output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
|
||||||
|
output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
|
||||||
|
|
||||||
output += dash_line
|
output += dash_line
|
||||||
|
|
||||||
# Other components section
|
# Other components section
|
||||||
@@ -711,9 +766,17 @@ class ComponentsManager:
|
|||||||
for name, component in others.items():
|
for name, component in others.items():
|
||||||
info = self.get_model_info(name)
|
info = self.get_model_info(name)
|
||||||
simple_name = get_simple_name(name)
|
simple_name = get_simple_name(name)
|
||||||
collection = info["collection"] or "N/A"
|
|
||||||
|
|
||||||
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
|
# Print first collection on the main line
|
||||||
|
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
|
||||||
|
|
||||||
|
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
|
||||||
|
|
||||||
|
# Print additional collections on separate lines if they exist
|
||||||
|
for i in range(1, len(component_collections[name])):
|
||||||
|
collection = component_collections[name][i]
|
||||||
|
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n"
|
||||||
|
|
||||||
output += dash_line
|
output += dash_line
|
||||||
|
|
||||||
# Add additional component info
|
# Add additional component info
|
||||||
@@ -775,7 +838,7 @@ class ComponentsManager:
|
|||||||
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
|
def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
Get a single component by name. Raises an error if multiple components match or none are found.
|
Get a single component by name. Raises an error if multiple components match or none are found.
|
||||||
|
|
||||||
@@ -790,6 +853,15 @@ class ComponentsManager:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If no components match or multiple components match
|
ValueError: If no components match or multiple components match
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# if component_id is provided, return the component
|
||||||
|
if component_id is not None and (name is not None or collection is not None or load_id is not None):
|
||||||
|
raise ValueError(" if component_id is provided, name, collection, and load_id must be None")
|
||||||
|
elif component_id is not None:
|
||||||
|
if component_id not in self.components:
|
||||||
|
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
|
||||||
|
return self.components[component_id]
|
||||||
|
|
||||||
results = self.get(name, collection, load_id)
|
results = self.get(name, collection, load_id)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
+1118
-235
File diff suppressed because it is too large
Load Diff
+68
-44
@@ -19,11 +19,30 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
|
|||||||
|
|
||||||
from ..utils.import_utils import is_torch_available
|
from ..utils.import_utils import is_torch_available
|
||||||
from ..configuration_utils import FrozenDict, ConfigMixin
|
from ..configuration_utils import FrozenDict, ConfigMixin
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class InsertableOrderedDict(OrderedDict):
|
||||||
|
def insert(self, key, value, index):
|
||||||
|
items = list(self.items())
|
||||||
|
|
||||||
|
# Remove key if it already exists to avoid duplicates
|
||||||
|
items = [(k, v) for k, v in items if k != key]
|
||||||
|
|
||||||
|
# Insert at the specified index
|
||||||
|
items.insert(index, (key, value))
|
||||||
|
|
||||||
|
# Clear and update self
|
||||||
|
self.clear()
|
||||||
|
self.update(items)
|
||||||
|
|
||||||
|
# Return self for method chaining
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
# YiYi TODO:
|
# YiYi TODO:
|
||||||
# 1. validate the dataclass fields
|
# 1. validate the dataclass fields
|
||||||
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
|
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
|
||||||
@@ -71,34 +90,31 @@ class ComponentSpec:
|
|||||||
self.default_creation_method == other.default_creation_method)
|
self.default_creation_method == other.default_creation_method)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_component(cls, name: str, component: torch.nn.Module) -> Any:
|
def from_component(cls, name: str, component: Any) -> Any:
|
||||||
"""Create a ComponentSpec from a Component created by `create` method."""
|
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
|
||||||
|
|
||||||
if not hasattr(component, "_diffusers_load_id"):
|
if not hasattr(component, "_diffusers_load_id"):
|
||||||
raise ValueError("Component is not created by `create` method")
|
raise ValueError("Component is not created by `create` or `load` method")
|
||||||
|
# throw a error if component is created with `create` method but not a subclass of ConfigMixin
|
||||||
|
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
|
||||||
|
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
|
||||||
|
raise ValueError(
|
||||||
|
"We currently only support creating ComponentSpec from a component with "
|
||||||
|
"created with `ComponentSpec.load` method"
|
||||||
|
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
|
||||||
|
)
|
||||||
|
|
||||||
type_hint = component.__class__
|
type_hint = component.__class__
|
||||||
|
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
|
||||||
|
|
||||||
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
|
if isinstance(component, ConfigMixin):
|
||||||
config = component.config
|
config = component.config
|
||||||
else:
|
else:
|
||||||
config = None
|
config = None
|
||||||
|
|
||||||
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
||||||
|
|
||||||
return cls(name=name, type_hint=type_hint, config=config, **load_spec)
|
return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
|
|
||||||
"""Create a ComponentSpec from a load_id string."""
|
|
||||||
if load_id == "null":
|
|
||||||
raise ValueError("Cannot create ComponentSpec from null load_id")
|
|
||||||
|
|
||||||
# Decode the load_id into a dictionary of loading fields
|
|
||||||
load_fields = cls.decode_load_id(load_id)
|
|
||||||
|
|
||||||
# Create a new ComponentSpec instance with the decoded fields
|
|
||||||
return cls(name=name, **load_fields)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def loading_fields(cls) -> List[str]:
|
def loading_fields(cls) -> List[str]:
|
||||||
@@ -137,7 +153,7 @@ class ComponentSpec:
|
|||||||
"revision": "revision"
|
"revision": "revision"
|
||||||
}
|
}
|
||||||
If a segment value is "null", it's replaced with None.
|
If a segment value is "null", it's replaced with None.
|
||||||
Returns None if load_id is "null" (indicating component not loaded from pretrained).
|
Returns None if load_id is "null" (indicating component not created with `load` method).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get all loading fields in order
|
# Get all loading fields in order
|
||||||
@@ -158,20 +174,12 @@ class ComponentSpec:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# YiYi TODO: add validator
|
|
||||||
def create(self, **kwargs) -> Any:
|
|
||||||
"""Create the component using the preferred creation method."""
|
|
||||||
|
|
||||||
# from_pretrained creation
|
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
|
||||||
if self.default_creation_method == "from_pretrained":
|
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
|
||||||
return self.create_from_pretrained(**kwargs)
|
# the config info is lost in the process
|
||||||
elif self.default_creation_method == "from_config":
|
# remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method
|
||||||
# from_config creation
|
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
|
||||||
return self.create_from_config(**kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid creation method: {self.default_creation_method}")
|
|
||||||
|
|
||||||
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
|
|
||||||
"""Create component using from_config with config."""
|
"""Create component using from_config with config."""
|
||||||
|
|
||||||
if self.type_hint is None or not isinstance(self.type_hint, type):
|
if self.type_hint is None or not isinstance(self.type_hint, type):
|
||||||
@@ -201,34 +209,35 @@ class ComponentSpec:
|
|||||||
return component
|
return component
|
||||||
|
|
||||||
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
||||||
def create_from_pretrained(self, **kwargs) -> Any:
|
def load(self, **kwargs) -> Any:
|
||||||
"""Create component using from_pretrained."""
|
"""Load component using from_pretrained."""
|
||||||
|
|
||||||
|
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
|
||||||
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
||||||
|
# merge loading field value in the spec with user passed values to create load_kwargs
|
||||||
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
||||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
||||||
repo = load_kwargs.pop("repo", None)
|
repo = load_kwargs.pop("repo", None)
|
||||||
if repo is None:
|
if repo is None:
|
||||||
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||||
|
|
||||||
if self.type_hint is None:
|
if self.type_hint is None:
|
||||||
try:
|
try:
|
||||||
from diffusers import AutoModel
|
from diffusers import AutoModel
|
||||||
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
|
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}")
|
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
|
||||||
|
# update type_hint if AutoModel load successfully
|
||||||
self.type_hint = component.__class__
|
self.type_hint = component.__class__
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
|
raise ValueError(f"Unable to load {self.name} using load method: {e}")
|
||||||
|
|
||||||
if repo != self.repo:
|
self.repo = repo
|
||||||
self.repo = repo
|
for k, v in load_kwargs.items():
|
||||||
for k, v in passed_loading_kwargs.items():
|
setattr(self, k, v)
|
||||||
if v is not None:
|
|
||||||
setattr(self, k, v)
|
|
||||||
component._diffusers_load_id = self.load_id
|
component._diffusers_load_id = self.load_id
|
||||||
|
|
||||||
return component
|
return component
|
||||||
@@ -241,14 +250,22 @@ class ConfigSpec:
|
|||||||
name: str
|
name: str
|
||||||
default: Any
|
default: Any
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# YiYi Notes: both inputs and intermediates_inputs are InputParam objects
|
||||||
|
# however some fields are not relevant for intermediates_inputs
|
||||||
|
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
|
||||||
|
# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs
|
||||||
|
# -> should we use different class for inputs and intermediates_inputs?
|
||||||
@dataclass
|
@dataclass
|
||||||
class InputParam:
|
class InputParam:
|
||||||
"""Specification for an input parameter."""
|
"""Specification for an input parameter."""
|
||||||
name: str
|
name: str = None
|
||||||
type_hint: Any = None
|
type_hint: Any = None
|
||||||
default: Any = None
|
default: Any = None
|
||||||
required: bool = False
|
required: bool = False
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||||
@@ -260,6 +277,7 @@ class OutputParam:
|
|||||||
name: str
|
name: str
|
||||||
type_hint: Any = None
|
type_hint: Any = None
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
||||||
@@ -320,7 +338,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
|
|||||||
if inp.name in required_intermediates_inputs:
|
if inp.name in required_intermediates_inputs:
|
||||||
input_parts.append(f"Required({inp.name})")
|
input_parts.append(f"Required({inp.name})")
|
||||||
else:
|
else:
|
||||||
input_parts.append(inp.name)
|
if inp.name is None and inp.kwargs_type is not None:
|
||||||
|
inp_name = "*_" + inp.kwargs_type
|
||||||
|
else:
|
||||||
|
inp_name = inp.name
|
||||||
|
input_parts.append(inp_name)
|
||||||
|
|
||||||
# Handle modified variables (appear in both inputs and outputs)
|
# Handle modified variables (appear in both inputs and outputs)
|
||||||
inputs_set = {inp.name for inp in intermediates_inputs}
|
inputs_set = {inp.name for inp in intermediates_inputs}
|
||||||
@@ -399,7 +421,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
|
|||||||
for param in params:
|
for param in params:
|
||||||
# Format parameter name and type
|
# Format parameter name and type
|
||||||
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
||||||
param_str = f"{param_indent}{param.name} (`{type_str}`"
|
# YiYi Notes: remove this line if we remove kwargs_type
|
||||||
|
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
|
||||||
|
param_str = f"{param_indent}{name} (`{type_str}`"
|
||||||
|
|
||||||
# Add optional tag and default value if parameter is an InputParam and optional
|
# Add optional tag and default value if parameter is an InputParam and optional
|
||||||
if hasattr(param, "required"):
|
if hasattr(param, "required"):
|
||||||
@@ -0,0 +1,519 @@
|
|||||||
|
from ..configuration_utils import ConfigMixin
|
||||||
|
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
|
||||||
|
from .modular_pipeline_utils import InputParam, OutputParam
|
||||||
|
from ..image_processor import PipelineImageInput
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Union, List, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
import PIL
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# YiYi Notes: this is actually for SDXL, put it here for now
|
||||||
|
SDXL_INPUTS_SCHEMA = {
|
||||||
|
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
|
||||||
|
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
|
||||||
|
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
|
||||||
|
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
|
||||||
|
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
|
||||||
|
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
|
||||||
|
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
|
||||||
|
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
|
||||||
|
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
|
||||||
|
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
|
||||||
|
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
|
||||||
|
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
|
||||||
|
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
|
||||||
|
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
|
||||||
|
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
|
||||||
|
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
|
||||||
|
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
|
||||||
|
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
|
||||||
|
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
|
||||||
|
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
|
||||||
|
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
|
||||||
|
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
|
||||||
|
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
|
||||||
|
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
|
||||||
|
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
|
||||||
|
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
|
||||||
|
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
|
||||||
|
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
|
||||||
|
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
|
||||||
|
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
|
||||||
|
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
|
||||||
|
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
|
||||||
|
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
|
||||||
|
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
|
||||||
|
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
|
||||||
|
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
|
||||||
|
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
|
||||||
|
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
|
||||||
|
}
|
||||||
|
|
||||||
|
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||||
|
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
|
||||||
|
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
|
||||||
|
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
|
||||||
|
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
|
||||||
|
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
|
||||||
|
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||||
|
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
|
||||||
|
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
|
||||||
|
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||||
|
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
|
||||||
|
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
|
||||||
|
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
|
||||||
|
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
|
||||||
|
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
|
||||||
|
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
|
||||||
|
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
|
||||||
|
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||||
|
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||||
|
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||||
|
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
|
||||||
|
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
|
||||||
|
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
|
||||||
|
}
|
||||||
|
|
||||||
|
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_PARAM_MAPS = {
|
||||||
|
"prompt": {
|
||||||
|
"label": "Prompt",
|
||||||
|
"type": "string",
|
||||||
|
"default": "a bear sitting in a chair drinking a milkshake",
|
||||||
|
"display": "textarea",
|
||||||
|
},
|
||||||
|
"negative_prompt": {
|
||||||
|
"label": "Negative Prompt",
|
||||||
|
"type": "string",
|
||||||
|
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||||
|
"display": "textarea",
|
||||||
|
},
|
||||||
|
|
||||||
|
"num_inference_steps": {
|
||||||
|
"label": "Steps",
|
||||||
|
"type": "int",
|
||||||
|
"default": 25,
|
||||||
|
"min": 1,
|
||||||
|
"max": 1000,
|
||||||
|
},
|
||||||
|
"seed": {
|
||||||
|
"label": "Seed",
|
||||||
|
"type": "int",
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"display": "random",
|
||||||
|
},
|
||||||
|
"width": {
|
||||||
|
"label": "Width",
|
||||||
|
"type": "int",
|
||||||
|
"display": "text",
|
||||||
|
"default": 1024,
|
||||||
|
"min": 8,
|
||||||
|
"max": 8192,
|
||||||
|
"step": 8,
|
||||||
|
"group": "dimensions",
|
||||||
|
},
|
||||||
|
"height": {
|
||||||
|
"label": "Height",
|
||||||
|
"type": "int",
|
||||||
|
"display": "text",
|
||||||
|
"default": 1024,
|
||||||
|
"min": 8,
|
||||||
|
"max": 8192,
|
||||||
|
"step": 8,
|
||||||
|
"group": "dimensions",
|
||||||
|
},
|
||||||
|
"images": {
|
||||||
|
"label": "Images",
|
||||||
|
"type": "image",
|
||||||
|
"display": "output",
|
||||||
|
},
|
||||||
|
"image": {
|
||||||
|
"label": "Image",
|
||||||
|
"type": "image",
|
||||||
|
"display": "input",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_TYPE_MAPS ={
|
||||||
|
"int": {
|
||||||
|
"type": "int",
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
},
|
||||||
|
"float": {
|
||||||
|
"type": "float",
|
||||||
|
"default": 0.0,
|
||||||
|
"min": 0.0,
|
||||||
|
},
|
||||||
|
"str": {
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
},
|
||||||
|
"bool": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
"image": {
|
||||||
|
"type": "image",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
|
||||||
|
DEFAULT_CATEGORY = "Modular Diffusers"
|
||||||
|
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
|
||||||
|
DEFAULT_PARAMS_GROUPS_KEYS = {
|
||||||
|
"text_encoders": ["text_encoder", "tokenizer"],
|
||||||
|
"ip_adapter_embeds": ["ip_adapter_embeds"],
|
||||||
|
"prompt_embeddings": ["prompt_embeds"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
|
||||||
|
"""
|
||||||
|
Get the group name for a given parameter name, if not part of a group, return None
|
||||||
|
e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
|
||||||
|
"""
|
||||||
|
if name is None:
|
||||||
|
return None
|
||||||
|
for group_name, group_keys in group_params_keys.items():
|
||||||
|
for group_key in group_keys:
|
||||||
|
if group_key in name:
|
||||||
|
return group_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ModularNode(ConfigMixin):
|
||||||
|
|
||||||
|
config_name = "node_config.json"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
pretrained_model_name_or_path: str,
|
||||||
|
trust_remote_code: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
|
||||||
|
return cls(blocks, **kwargs)
|
||||||
|
|
||||||
|
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
|
||||||
|
self.blocks = blocks
|
||||||
|
|
||||||
|
if label is None:
|
||||||
|
label = self.blocks.__class__.__name__
|
||||||
|
# blocks param name -> mellon param name
|
||||||
|
self.name_mapping = {}
|
||||||
|
|
||||||
|
input_params = {}
|
||||||
|
# pass or create a default param dict for each input
|
||||||
|
# e.g. for prompt,
|
||||||
|
# prompt = {
|
||||||
|
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
|
||||||
|
# "label": "Prompt",
|
||||||
|
# "type": "string",
|
||||||
|
# "default": "a bear sitting in a chair drinking a milkshake",
|
||||||
|
# "display": "textarea"}
|
||||||
|
# if type is not specified, it'll be a "custom" param of its own type
|
||||||
|
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
|
||||||
|
# it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
|
||||||
|
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
|
||||||
|
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
|
||||||
|
for inp in inputs:
|
||||||
|
param = kwargs.pop(inp.name, None)
|
||||||
|
if param:
|
||||||
|
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
|
||||||
|
input_params[inp.name] = param
|
||||||
|
mellon_name = param.pop("name", inp.name)
|
||||||
|
if mellon_name != inp.name:
|
||||||
|
self.name_mapping[inp.name] = mellon_name
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if inp.name in DEFAULT_PARAM_MAPS:
|
||||||
|
# first check if it's in the default param map, if so, directly use that
|
||||||
|
param = DEFAULT_PARAM_MAPS[inp.name].copy()
|
||||||
|
elif get_group_name(inp.name):
|
||||||
|
param = get_group_name(inp.name)
|
||||||
|
if inp.name not in self.name_mapping:
|
||||||
|
self.name_mapping[inp.name] = param
|
||||||
|
else:
|
||||||
|
# if not, check if it's in the SDXL input schema, if so,
|
||||||
|
# 1. use the type hint to determine the type
|
||||||
|
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
|
||||||
|
if inp.type_hint is not None:
|
||||||
|
type_str = str(inp.type_hint).lower()
|
||||||
|
else:
|
||||||
|
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
|
||||||
|
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
|
||||||
|
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
|
||||||
|
if type_key in type_str:
|
||||||
|
param = type_param.copy()
|
||||||
|
param["label"] = inp.name
|
||||||
|
param["display"] = "input"
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
param = inp.name
|
||||||
|
# add the param dict to the inp_params dict
|
||||||
|
input_params[inp.name] = param
|
||||||
|
|
||||||
|
|
||||||
|
component_params = {}
|
||||||
|
for comp in self.blocks.expected_components:
|
||||||
|
param = kwargs.pop(comp.name, None)
|
||||||
|
if param:
|
||||||
|
component_params[comp.name] = param
|
||||||
|
mellon_name = param.pop("name", comp.name)
|
||||||
|
if mellon_name != comp.name:
|
||||||
|
self.name_mapping[comp.name] = mellon_name
|
||||||
|
continue
|
||||||
|
|
||||||
|
to_exclude = False
|
||||||
|
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
|
||||||
|
if exclude_key in comp.name:
|
||||||
|
to_exclude = True
|
||||||
|
break
|
||||||
|
if to_exclude:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if get_group_name(comp.name):
|
||||||
|
param = get_group_name(comp.name)
|
||||||
|
if comp.name not in self.name_mapping:
|
||||||
|
self.name_mapping[comp.name] = param
|
||||||
|
elif comp.name in DEFAULT_MODEL_KEYS:
|
||||||
|
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
|
||||||
|
else:
|
||||||
|
param = comp.name
|
||||||
|
# add the param dict to the model_params dict
|
||||||
|
component_params[comp.name] = param
|
||||||
|
|
||||||
|
output_params = {}
|
||||||
|
if isinstance(self.blocks, SequentialPipelineBlocks):
|
||||||
|
last_block_name = list(self.blocks.blocks.keys())[-1]
|
||||||
|
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
|
||||||
|
else:
|
||||||
|
outputs = self.blocks.intermediates_outputs
|
||||||
|
|
||||||
|
for out in outputs:
|
||||||
|
param = kwargs.pop(out.name, None)
|
||||||
|
if param:
|
||||||
|
output_params[out.name] = param
|
||||||
|
mellon_name = param.pop("name", out.name)
|
||||||
|
if mellon_name != out.name:
|
||||||
|
self.name_mapping[out.name] = mellon_name
|
||||||
|
continue
|
||||||
|
|
||||||
|
if out.name in DEFAULT_PARAM_MAPS:
|
||||||
|
param = DEFAULT_PARAM_MAPS[out.name].copy()
|
||||||
|
param["display"] = "output"
|
||||||
|
else:
|
||||||
|
group_name = get_group_name(out.name)
|
||||||
|
if group_name:
|
||||||
|
param = group_name
|
||||||
|
if out.name not in self.name_mapping:
|
||||||
|
self.name_mapping[out.name] = param
|
||||||
|
else:
|
||||||
|
param = out.name
|
||||||
|
# add the param dict to the outputs dict
|
||||||
|
output_params[out.name] = param
|
||||||
|
|
||||||
|
if len(kwargs) > 0:
|
||||||
|
logger.warning(f"Unused kwargs: {kwargs}")
|
||||||
|
|
||||||
|
register_dict = {
|
||||||
|
"category": category,
|
||||||
|
"label": label,
|
||||||
|
"input_params": input_params,
|
||||||
|
"component_params": component_params,
|
||||||
|
"output_params": output_params,
|
||||||
|
"name_mapping": self.name_mapping,
|
||||||
|
}
|
||||||
|
self.register_to_config(**register_dict)
|
||||||
|
|
||||||
|
def setup(self, components, collection=None):
|
||||||
|
self.blocks.setup_loader(component_manager=components, collection=collection)
|
||||||
|
self._components_manager = components
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mellon_config(self):
|
||||||
|
return self._convert_to_mellon_config()
|
||||||
|
|
||||||
|
def _convert_to_mellon_config(self):
|
||||||
|
|
||||||
|
node = {}
|
||||||
|
node["label"] = self.config.label
|
||||||
|
node["category"] = self.config.category
|
||||||
|
|
||||||
|
node_param = {}
|
||||||
|
for inp_name, inp_param in self.config.input_params.items():
|
||||||
|
if inp_name in self.name_mapping:
|
||||||
|
mellon_name = self.name_mapping[inp_name]
|
||||||
|
else:
|
||||||
|
mellon_name = inp_name
|
||||||
|
if isinstance(inp_param, str):
|
||||||
|
param = {
|
||||||
|
"label": inp_param,
|
||||||
|
"type": inp_param,
|
||||||
|
"display": "input",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
param = inp_param
|
||||||
|
|
||||||
|
if mellon_name not in node_param:
|
||||||
|
node_param[mellon_name] = param
|
||||||
|
else:
|
||||||
|
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
|
||||||
|
|
||||||
|
|
||||||
|
for comp_name, comp_param in self.config.component_params.items():
|
||||||
|
if comp_name in self.name_mapping:
|
||||||
|
mellon_name = self.name_mapping[comp_name]
|
||||||
|
else:
|
||||||
|
mellon_name = comp_name
|
||||||
|
if isinstance(comp_param, str):
|
||||||
|
param = {
|
||||||
|
"label": comp_param,
|
||||||
|
"type": comp_param,
|
||||||
|
"display": "input",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
param = comp_param
|
||||||
|
|
||||||
|
if mellon_name not in node_param:
|
||||||
|
node_param[mellon_name] = param
|
||||||
|
else:
|
||||||
|
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
|
||||||
|
|
||||||
|
|
||||||
|
for out_name, out_param in self.config.output_params.items():
|
||||||
|
if out_name in self.name_mapping:
|
||||||
|
mellon_name = self.name_mapping[out_name]
|
||||||
|
else:
|
||||||
|
mellon_name = out_name
|
||||||
|
if isinstance(out_param, str):
|
||||||
|
param = {
|
||||||
|
"label": out_param,
|
||||||
|
"type": out_param,
|
||||||
|
"display": "output",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
param = out_param
|
||||||
|
|
||||||
|
if mellon_name not in node_param:
|
||||||
|
node_param[mellon_name] = param
|
||||||
|
else:
|
||||||
|
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
|
||||||
|
node["params"] = node_param
|
||||||
|
return node
|
||||||
|
|
||||||
|
def save_mellon_config(self, file_path):
|
||||||
|
"""
|
||||||
|
Save the Mellon configuration to a JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str or Path): Path where the JSON file will be saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to the saved config file
|
||||||
|
"""
|
||||||
|
file_path = Path(file_path)
|
||||||
|
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
os.makedirs(file_path.parent, exist_ok=True)
|
||||||
|
|
||||||
|
# Create a combined dictionary with module definition and name mapping
|
||||||
|
config = {
|
||||||
|
"module": self.mellon_config,
|
||||||
|
"name_mapping": self.name_mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save the config to file
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"Mellon config and name mapping saved to {file_path}")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_mellon_config(cls, file_path):
|
||||||
|
"""
|
||||||
|
Load a Mellon configuration from a JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str or Path): Path to the JSON file containing Mellon config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The loaded combined configuration containing 'module' and 'name_mapping'
|
||||||
|
"""
|
||||||
|
file_path = Path(file_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"Config file not found: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
logger.info(f"Mellon config loaded from {file_path}")
|
||||||
|
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def process_inputs(self, **kwargs):
|
||||||
|
|
||||||
|
params_components = {}
|
||||||
|
for comp_name, comp_param in self.config.component_params.items():
|
||||||
|
logger.debug(f"component: {comp_name}")
|
||||||
|
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
|
||||||
|
if mellon_comp_name in kwargs:
|
||||||
|
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
|
||||||
|
comp = kwargs[mellon_comp_name].pop(comp_name)
|
||||||
|
else:
|
||||||
|
comp = kwargs.pop(mellon_comp_name)
|
||||||
|
if comp:
|
||||||
|
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
|
||||||
|
|
||||||
|
|
||||||
|
params_run = {}
|
||||||
|
for inp_name, inp_param in self.config.input_params.items():
|
||||||
|
logger.debug(f"input: {inp_name}")
|
||||||
|
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
|
||||||
|
if mellon_inp_name in kwargs:
|
||||||
|
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
|
||||||
|
inp = kwargs[mellon_inp_name].pop(inp_name)
|
||||||
|
else:
|
||||||
|
inp = kwargs.pop(mellon_inp_name)
|
||||||
|
if inp is not None:
|
||||||
|
params_run[inp_name] = inp
|
||||||
|
|
||||||
|
return_output_names = list(self.config.output_params.keys())
|
||||||
|
|
||||||
|
return params_components, params_run, return_output_names
|
||||||
|
|
||||||
|
def execute(self, **kwargs):
|
||||||
|
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
|
||||||
|
|
||||||
|
self.blocks.loader.update(**params_components)
|
||||||
|
output = self.blocks.run(**params_run, output=return_output_names)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import (
|
||||||
|
DIFFUSERS_SLOW_IMPORT,
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
get_objects_from_module,
|
||||||
|
is_torch_available,
|
||||||
|
is_transformers_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_dummy_objects = {}
|
||||||
|
_import_structure = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
|
else:
|
||||||
|
_import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"]
|
||||||
|
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
|
||||||
|
_import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"]
|
||||||
|
_import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"]
|
||||||
|
_import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"]
|
||||||
|
|
||||||
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
|
else:
|
||||||
|
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
|
||||||
|
from .modular_loader import StableDiffusionXLModularLoader
|
||||||
|
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
|
||||||
|
from .decoders import StableDiffusionXLAutoDecodeStep
|
||||||
|
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__,
|
||||||
|
globals()["__file__"],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,215 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, List, Optional, Tuple, Union, Dict
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from ...image_processor import VaeImageProcessor, PipelineImageInput
|
||||||
|
from ...models import AutoencoderKL
|
||||||
|
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
|
from ..modular_pipeline import (
|
||||||
|
AutoPipelineBlocks,
|
||||||
|
PipelineBlock,
|
||||||
|
PipelineState,
|
||||||
|
SequentialPipelineBlocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||||
|
|
||||||
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("vae", AutoencoderKL),
|
||||||
|
ComponentSpec(
|
||||||
|
"image_processor",
|
||||||
|
VaeImageProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 8}),
|
||||||
|
default_creation_method="from_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that decodes the denoised latents into images"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("output_type", default="pil"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_inputs(self) -> List[str]:
|
||||||
|
return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_outputs(self) -> List[str]:
|
||||||
|
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
|
||||||
|
@staticmethod
|
||||||
|
def upcast_vae(components):
|
||||||
|
dtype = components.vae.dtype
|
||||||
|
components.vae.to(dtype=torch.float32)
|
||||||
|
use_torch_2_0_or_xformers = isinstance(
|
||||||
|
components.vae.decoder.mid_block.attentions[0].processor,
|
||||||
|
(
|
||||||
|
AttnProcessor2_0,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
|
# to be in float32 which can save lots of memory
|
||||||
|
if use_torch_2_0_or_xformers:
|
||||||
|
components.vae.post_quant_conv.to(dtype)
|
||||||
|
components.vae.decoder.conv_in.to(dtype)
|
||||||
|
components.vae.decoder.mid_block.to(dtype)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
if not block_state.output_type == "latent":
|
||||||
|
latents = block_state.latents
|
||||||
|
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||||
|
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
|
||||||
|
|
||||||
|
if block_state.needs_upcasting:
|
||||||
|
self.upcast_vae(components)
|
||||||
|
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
|
||||||
|
elif latents.dtype != components.vae.dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||||
|
components.vae = components.vae.to(latents.dtype)
|
||||||
|
|
||||||
|
# unscale/denormalize the latents
|
||||||
|
# denormalize with the mean and std if available and not None
|
||||||
|
block_state.has_latents_mean = (
|
||||||
|
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
|
||||||
|
)
|
||||||
|
block_state.has_latents_std = (
|
||||||
|
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
|
||||||
|
)
|
||||||
|
if block_state.has_latents_mean and block_state.has_latents_std:
|
||||||
|
block_state.latents_mean = (
|
||||||
|
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||||
|
)
|
||||||
|
block_state.latents_std = (
|
||||||
|
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||||
|
)
|
||||||
|
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
|
||||||
|
else:
|
||||||
|
latents = latents / components.vae.config.scaling_factor
|
||||||
|
|
||||||
|
block_state.images = components.vae.decode(latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
# cast back to fp16 if needed
|
||||||
|
if block_state.needs_upcasting:
|
||||||
|
components.vae.to(dtype=torch.float16)
|
||||||
|
else:
|
||||||
|
block_state.images = block_state.latents
|
||||||
|
|
||||||
|
# apply watermark if available
|
||||||
|
if hasattr(components, "watermark") and components.watermark is not None:
|
||||||
|
block_state.images = components.watermark.apply_watermark(block_state.images)
|
||||||
|
|
||||||
|
block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type)
|
||||||
|
|
||||||
|
self.add_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||||
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \
|
||||||
|
"only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("image", required=True),
|
||||||
|
InputParam("mask_image", required=True),
|
||||||
|
InputParam("padding_mask_crop"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"),
|
||||||
|
InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.")
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_outputs(self) -> List[str]:
|
||||||
|
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
|
||||||
|
block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images]
|
||||||
|
|
||||||
|
self.add_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
|
||||||
|
block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
|
||||||
|
block_names = ["decode", "mask_overlay"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \
|
||||||
|
"This is a sequential pipeline blocks:\n" + \
|
||||||
|
" - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \
|
||||||
|
" - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
|
||||||
|
block_names = ["inpaint", "non-inpaint"]
|
||||||
|
block_trigger_inputs = ["padding_mask_crop", None]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Decode step that decode the denoised latents into images outputs.\n" + \
|
||||||
|
"This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
|
||||||
|
" - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
|
||||||
|
" - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
|
||||||
|
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,858 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, List, Optional, Tuple, Union, Dict
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from ...image_processor import VaeImageProcessor, PipelineImageInput
|
||||||
|
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
|
||||||
|
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
|
||||||
|
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
|
from ...models.lora import adjust_lora_scale_text_encoder
|
||||||
|
from ...utils import (
|
||||||
|
USE_PEFT_BACKEND,
|
||||||
|
logging,
|
||||||
|
scale_lora_layers,
|
||||||
|
unscale_lora_layers,
|
||||||
|
)
|
||||||
|
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||||
|
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
CLIPTextModel,
|
||||||
|
CLIPImageProcessor,
|
||||||
|
CLIPTextModelWithProjection,
|
||||||
|
CLIPTokenizer,
|
||||||
|
CLIPVisionModelWithProjection,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...schedulers import EulerDiscreteScheduler
|
||||||
|
from ...guiders import ClassifierFreeGuidance
|
||||||
|
|
||||||
|
from .modular_loader import StableDiffusionXLModularLoader
|
||||||
|
from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||||
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
|
||||||
|
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
|
||||||
|
" for more details"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
|
||||||
|
ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
|
||||||
|
ComponentSpec("unet", UNet2DConditionModel),
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 7.5}),
|
||||||
|
default_creation_method="from_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"ip_adapter_image",
|
||||||
|
PipelineImageInput,
|
||||||
|
required=True,
|
||||||
|
description="The image(s) to be used as ip adapter"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
|
||||||
|
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
|
||||||
|
@staticmethod
|
||||||
|
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||||
|
dtype = next(components.image_encoder.parameters()).dtype
|
||||||
|
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
image = components.feature_extractor(image, return_tensors="pt").pixel_values
|
||||||
|
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
if output_hidden_states:
|
||||||
|
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||||
|
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||||
|
uncond_image_enc_hidden_states = components.image_encoder(
|
||||||
|
torch.zeros_like(image), output_hidden_states=True
|
||||||
|
).hidden_states[-2]
|
||||||
|
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||||
|
num_images_per_prompt, dim=0
|
||||||
|
)
|
||||||
|
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||||
|
else:
|
||||||
|
image_embeds = components.image_encoder(image).image_embeds
|
||||||
|
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||||
|
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||||
|
|
||||||
|
return image_embeds, uncond_image_embeds
|
||||||
|
|
||||||
|
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||||
|
def prepare_ip_adapter_image_embeds(
|
||||||
|
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
|
||||||
|
):
|
||||||
|
image_embeds = []
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
negative_image_embeds = []
|
||||||
|
if ip_adapter_image_embeds is None:
|
||||||
|
if not isinstance(ip_adapter_image, list):
|
||||||
|
ip_adapter_image = [ip_adapter_image]
|
||||||
|
|
||||||
|
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
|
||||||
|
raise ValueError(
|
||||||
|
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||||
|
)
|
||||||
|
|
||||||
|
for single_ip_adapter_image, image_proj_layer in zip(
|
||||||
|
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
|
||||||
|
):
|
||||||
|
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||||
|
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||||
|
components, single_ip_adapter_image, device, 1, output_hidden_state
|
||||||
|
)
|
||||||
|
|
||||||
|
image_embeds.append(single_image_embeds[None, :])
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
negative_image_embeds.append(single_negative_image_embeds[None, :])
|
||||||
|
else:
|
||||||
|
for single_image_embeds in ip_adapter_image_embeds:
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||||
|
negative_image_embeds.append(single_negative_image_embeds)
|
||||||
|
image_embeds.append(single_image_embeds)
|
||||||
|
|
||||||
|
ip_adapter_image_embeds = []
|
||||||
|
for i, single_image_embeds in enumerate(image_embeds):
|
||||||
|
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
|
||||||
|
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
|
||||||
|
|
||||||
|
single_image_embeds = single_image_embeds.to(device=device)
|
||||||
|
ip_adapter_image_embeds.append(single_image_embeds)
|
||||||
|
|
||||||
|
return ip_adapter_image_embeds
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
|
||||||
|
components,
|
||||||
|
ip_adapter_image=block_state.ip_adapter_image,
|
||||||
|
ip_adapter_image_embeds=None,
|
||||||
|
device=block_state.device,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
|
||||||
|
)
|
||||||
|
if block_state.prepare_unconditional_embeds:
|
||||||
|
block_state.negative_ip_adapter_embeds = []
|
||||||
|
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
|
||||||
|
negative_image_embeds, image_embeds = image_embeds.chunk(2)
|
||||||
|
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
|
||||||
|
block_state.ip_adapter_embeds[i] = image_embeds
|
||||||
|
|
||||||
|
self.add_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||||
|
|
||||||
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return(
|
||||||
|
"Text Encoder step that generate text_embeddings to guide the image generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("text_encoder", CLIPTextModel),
|
||||||
|
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
|
||||||
|
ComponentSpec("tokenizer", CLIPTokenizer),
|
||||||
|
ComponentSpec("tokenizer_2", CLIPTokenizer),
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 7.5}),
|
||||||
|
default_creation_method="from_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_configs(self) -> List[ConfigSpec]:
|
||||||
|
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("prompt"),
|
||||||
|
InputParam("prompt_2"),
|
||||||
|
InputParam("negative_prompt"),
|
||||||
|
InputParam("negative_prompt_2"),
|
||||||
|
InputParam("cross_attention_kwargs"),
|
||||||
|
InputParam("clip_skip"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"),
|
||||||
|
OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"),
|
||||||
|
OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"),
|
||||||
|
OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(block_state):
|
||||||
|
|
||||||
|
if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||||
|
elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)):
|
||||||
|
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode_prompt(
|
||||||
|
components,
|
||||||
|
prompt: str,
|
||||||
|
prompt_2: Optional[str] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
prepare_unconditional_embeds: bool = True,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
negative_prompt_2: Optional[str] = None,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
lora_scale: Optional[float] = None,
|
||||||
|
clip_skip: Optional[int] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Encodes the prompt into text encoder hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
prompt to be encoded
|
||||||
|
prompt_2 (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||||
|
used in both text-encoders
|
||||||
|
device: (`torch.device`):
|
||||||
|
torch device
|
||||||
|
num_images_per_prompt (`int`):
|
||||||
|
number of images that should be generated per prompt
|
||||||
|
prepare_unconditional_embeds (`bool`):
|
||||||
|
whether to use prepare unconditional embeddings or not
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||||
|
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
||||||
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||||
|
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||||
|
input argument.
|
||||||
|
lora_scale (`float`, *optional*):
|
||||||
|
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||||
|
clip_skip (`int`, *optional*):
|
||||||
|
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||||
|
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||||
|
"""
|
||||||
|
device = device or components._execution_device
|
||||||
|
|
||||||
|
# set lora scale so that monkey patched LoRA
|
||||||
|
# function of text encoder can correctly access it
|
||||||
|
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
|
||||||
|
components._lora_scale = lora_scale
|
||||||
|
|
||||||
|
# dynamically adjust the LoRA scale
|
||||||
|
if components.text_encoder is not None:
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
|
||||||
|
else:
|
||||||
|
scale_lora_layers(components.text_encoder, lora_scale)
|
||||||
|
|
||||||
|
if components.text_encoder_2 is not None:
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
|
||||||
|
else:
|
||||||
|
scale_lora_layers(components.text_encoder_2, lora_scale)
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
if prompt is not None:
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
# Define tokenizers and text encoders
|
||||||
|
tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2]
|
||||||
|
text_encoders = (
|
||||||
|
[components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_2 = prompt_2 or prompt
|
||||||
|
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||||
|
|
||||||
|
# textual inversion: process multi-vector tokens if necessary
|
||||||
|
prompt_embeds_list = []
|
||||||
|
prompts = [prompt, prompt_2]
|
||||||
|
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
||||||
|
if isinstance(components, TextualInversionLoaderMixin):
|
||||||
|
prompt = components.maybe_convert_prompt(prompt, tokenizer)
|
||||||
|
|
||||||
|
text_inputs = tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_input_ids = text_inputs.input_ids
|
||||||
|
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||||
|
text_input_ids, untruncated_ids
|
||||||
|
):
|
||||||
|
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
||||||
|
logger.warning(
|
||||||
|
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||||
|
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||||
|
|
||||||
|
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||||
|
pooled_prompt_embeds = prompt_embeds[0]
|
||||||
|
if clip_skip is None:
|
||||||
|
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||||
|
else:
|
||||||
|
# "2" because SDXL always indexes from the penultimate layer.
|
||||||
|
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
||||||
|
|
||||||
|
prompt_embeds_list.append(prompt_embeds)
|
||||||
|
|
||||||
|
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||||
|
|
||||||
|
# get unconditional embeddings for classifier free guidance
|
||||||
|
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
|
||||||
|
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
|
||||||
|
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||||
|
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||||
|
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
|
||||||
|
negative_prompt = negative_prompt or ""
|
||||||
|
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||||
|
|
||||||
|
# normalize str to list
|
||||||
|
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||||
|
negative_prompt_2 = (
|
||||||
|
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||||
|
)
|
||||||
|
|
||||||
|
uncond_tokens: List[str]
|
||||||
|
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||||
|
raise TypeError(
|
||||||
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
|
f" {type(prompt)}."
|
||||||
|
)
|
||||||
|
elif batch_size != len(negative_prompt):
|
||||||
|
raise ValueError(
|
||||||
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||||
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||||
|
" the batch size of `prompt`."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||||
|
|
||||||
|
negative_prompt_embeds_list = []
|
||||||
|
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
||||||
|
if isinstance(components, TextualInversionLoaderMixin):
|
||||||
|
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
|
||||||
|
|
||||||
|
max_length = prompt_embeds.shape[1]
|
||||||
|
uncond_input = tokenizer(
|
||||||
|
negative_prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
negative_prompt_embeds = text_encoder(
|
||||||
|
uncond_input.input_ids.to(device),
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||||
|
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||||
|
|
||||||
|
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||||
|
|
||||||
|
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||||
|
|
||||||
|
if components.text_encoder_2 is not None:
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
|
||||||
|
else:
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
|
||||||
|
|
||||||
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||||
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||||
|
seq_len = negative_prompt_embeds.shape[1]
|
||||||
|
|
||||||
|
if components.text_encoder_2 is not None:
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
|
||||||
|
else:
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
|
||||||
|
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||||
|
bs_embed * num_images_per_prompt, -1
|
||||||
|
)
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||||
|
bs_embed * num_images_per_prompt, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
if components.text_encoder is not None:
|
||||||
|
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||||
|
# Retrieve the original scale by scaling back the LoRA layers
|
||||||
|
unscale_lora_layers(components.text_encoder, lora_scale)
|
||||||
|
|
||||||
|
if components.text_encoder_2 is not None:
|
||||||
|
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||||
|
# Retrieve the original scale by scaling back the LoRA layers
|
||||||
|
unscale_lora_layers(components.text_encoder_2, lora_scale)
|
||||||
|
|
||||||
|
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
|
||||||
|
# Get inputs and intermediates
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(block_state)
|
||||||
|
|
||||||
|
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
# Encode input prompt
|
||||||
|
block_state.text_encoder_lora_scale = (
|
||||||
|
block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None
|
||||||
|
)
|
||||||
|
(
|
||||||
|
block_state.prompt_embeds,
|
||||||
|
block_state.negative_prompt_embeds,
|
||||||
|
block_state.pooled_prompt_embeds,
|
||||||
|
block_state.negative_pooled_prompt_embeds,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
components,
|
||||||
|
block_state.prompt,
|
||||||
|
block_state.prompt_2,
|
||||||
|
block_state.device,
|
||||||
|
1,
|
||||||
|
block_state.prepare_unconditional_embeds,
|
||||||
|
block_state.negative_prompt,
|
||||||
|
block_state.negative_prompt_2,
|
||||||
|
prompt_embeds=None,
|
||||||
|
negative_prompt_embeds=None,
|
||||||
|
pooled_prompt_embeds=None,
|
||||||
|
negative_pooled_prompt_embeds=None,
|
||||||
|
lora_scale=block_state.text_encoder_lora_scale,
|
||||||
|
clip_skip=block_state.clip_skip,
|
||||||
|
)
|
||||||
|
# Add outputs
|
||||||
|
self.add_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||||
|
|
||||||
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Vae Encoder step that encode the input image into a latent representation"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("vae", AutoencoderKL),
|
||||||
|
ComponentSpec(
|
||||||
|
"image_processor",
|
||||||
|
VaeImageProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 8}),
|
||||||
|
default_creation_method="from_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("image", required=True),
|
||||||
|
InputParam("height"),
|
||||||
|
InputParam("width"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("generator"),
|
||||||
|
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||||
|
InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_outputs(self) -> List[OutputParam]:
|
||||||
|
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
|
||||||
|
|
||||||
|
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
|
||||||
|
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
|
||||||
|
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
|
||||||
|
|
||||||
|
latents_mean = latents_std = None
|
||||||
|
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
|
||||||
|
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||||
|
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
|
||||||
|
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
|
||||||
|
|
||||||
|
dtype = image.dtype
|
||||||
|
if components.vae.config.force_upcast:
|
||||||
|
image = image.float()
|
||||||
|
components.vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
if isinstance(generator, list):
|
||||||
|
image_latents = [
|
||||||
|
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||||
|
for i in range(image.shape[0])
|
||||||
|
]
|
||||||
|
image_latents = torch.cat(image_latents, dim=0)
|
||||||
|
else:
|
||||||
|
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
|
||||||
|
|
||||||
|
if components.vae.config.force_upcast:
|
||||||
|
components.vae.to(dtype)
|
||||||
|
|
||||||
|
image_latents = image_latents.to(dtype)
|
||||||
|
if latents_mean is not None and latents_std is not None:
|
||||||
|
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
|
||||||
|
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
|
||||||
|
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
|
||||||
|
else:
|
||||||
|
image_latents = components.vae.config.scaling_factor * image_latents
|
||||||
|
|
||||||
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||||
|
|
||||||
|
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
|
||||||
|
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||||
|
|
||||||
|
block_state.batch_size = block_state.image.shape[0]
|
||||||
|
|
||||||
|
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||||
|
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||||
|
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
|
||||||
|
|
||||||
|
self.add_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||||
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("vae", AutoencoderKL),
|
||||||
|
ComponentSpec(
|
||||||
|
"image_processor",
|
||||||
|
VaeImageProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 8}),
|
||||||
|
default_creation_method="from_config"),
|
||||||
|
ComponentSpec(
|
||||||
|
"mask_processor",
|
||||||
|
VaeImageProcessor,
|
||||||
|
config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
|
||||||
|
default_creation_method="from_config"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Vae encoder step that prepares the image and mask for the inpainting process"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("height"),
|
||||||
|
InputParam("width"),
|
||||||
|
InputParam("image", required=True),
|
||||||
|
InputParam("mask_image", required=True),
|
||||||
|
InputParam("padding_mask_crop"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||||
|
InputParam("generator"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediates_outputs(self) -> List[OutputParam]:
|
||||||
|
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
|
||||||
|
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
|
||||||
|
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
|
||||||
|
OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
|
||||||
|
|
||||||
|
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
|
||||||
|
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
|
||||||
|
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
|
||||||
|
|
||||||
|
latents_mean = latents_std = None
|
||||||
|
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
|
||||||
|
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||||
|
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
|
||||||
|
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
|
||||||
|
|
||||||
|
dtype = image.dtype
|
||||||
|
if components.vae.config.force_upcast:
|
||||||
|
image = image.float()
|
||||||
|
components.vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
if isinstance(generator, list):
|
||||||
|
image_latents = [
|
||||||
|
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||||
|
for i in range(image.shape[0])
|
||||||
|
]
|
||||||
|
image_latents = torch.cat(image_latents, dim=0)
|
||||||
|
else:
|
||||||
|
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
|
||||||
|
|
||||||
|
if components.vae.config.force_upcast:
|
||||||
|
components.vae.to(dtype)
|
||||||
|
|
||||||
|
image_latents = image_latents.to(dtype)
|
||||||
|
if latents_mean is not None and latents_std is not None:
|
||||||
|
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
|
||||||
|
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
|
||||||
|
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
|
||||||
|
else:
|
||||||
|
image_latents = components.vae.config.scaling_factor * image_latents
|
||||||
|
|
||||||
|
return image_latents
|
||||||
|
|
||||||
|
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
|
||||||
|
# do not accept do_classifier_free_guidance
|
||||||
|
def prepare_mask_latents(
|
||||||
|
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
|
||||||
|
):
|
||||||
|
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||||
|
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||||
|
# and half precision
|
||||||
|
mask = torch.nn.functional.interpolate(
|
||||||
|
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
|
||||||
|
)
|
||||||
|
mask = mask.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||||
|
if mask.shape[0] < batch_size:
|
||||||
|
if not batch_size % mask.shape[0] == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||||
|
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||||
|
" of masks that you pass is divisible by the total requested batch size."
|
||||||
|
)
|
||||||
|
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||||
|
|
||||||
|
if masked_image is not None and masked_image.shape[1] == 4:
|
||||||
|
masked_image_latents = masked_image
|
||||||
|
else:
|
||||||
|
masked_image_latents = None
|
||||||
|
|
||||||
|
if masked_image is not None:
|
||||||
|
if masked_image_latents is None:
|
||||||
|
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||||
|
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
|
||||||
|
|
||||||
|
if masked_image_latents.shape[0] < batch_size:
|
||||||
|
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||||
|
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||||
|
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||||
|
)
|
||||||
|
masked_image_latents = masked_image_latents.repeat(
|
||||||
|
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# aligning device to prevent device errors when concating it with the latent model input
|
||||||
|
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return mask, masked_image_latents
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
|
||||||
|
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
if block_state.padding_mask_crop is not None:
|
||||||
|
block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop)
|
||||||
|
block_state.resize_mode = "fill"
|
||||||
|
else:
|
||||||
|
block_state.crops_coords = None
|
||||||
|
block_state.resize_mode = "default"
|
||||||
|
|
||||||
|
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
|
||||||
|
block_state.image = block_state.image.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords)
|
||||||
|
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
|
||||||
|
|
||||||
|
block_state.batch_size = block_state.image.shape[0]
|
||||||
|
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||||
|
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
|
||||||
|
|
||||||
|
# 7. Prepare mask latent variables
|
||||||
|
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
|
||||||
|
components,
|
||||||
|
block_state.mask,
|
||||||
|
block_state.masked_image,
|
||||||
|
block_state.batch_size,
|
||||||
|
block_state.height,
|
||||||
|
block_state.width,
|
||||||
|
block_state.dtype,
|
||||||
|
block_state.device,
|
||||||
|
block_state.generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_block_state(state, block_state)
|
||||||
|
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
|
||||||
|
# Encode
|
||||||
|
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
|
||||||
|
block_names = ["inpaint", "img2img"]
|
||||||
|
block_trigger_inputs = ["mask_image", "image"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Vae encoder step that encode the image inputs into their latent representations.\n" + \
|
||||||
|
"This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \
|
||||||
|
" - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \
|
||||||
|
" - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin):
|
||||||
|
block_classes = [StableDiffusionXLIPAdapterStep]
|
||||||
|
block_names = ["ip_adapter"]
|
||||||
|
block_trigger_inputs = ["ip_adapter_image"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Run IP Adapter step if `ip_adapter_image` is provided."
|
||||||
|
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ..modular_pipeline_utils import InsertableOrderedDict
|
||||||
|
|
||||||
|
# Import all the necessary block classes
|
||||||
|
from .denoise import (
|
||||||
|
StableDiffusionXLAutoDenoiseStep,
|
||||||
|
StableDiffusionXLControlNetDenoiseStep,
|
||||||
|
StableDiffusionXLDenoiseLoop,
|
||||||
|
StableDiffusionXLInpaintDenoiseLoop
|
||||||
|
)
|
||||||
|
from .before_denoise import (
|
||||||
|
StableDiffusionXLAutoBeforeDenoiseStep,
|
||||||
|
StableDiffusionXLInputStep,
|
||||||
|
StableDiffusionXLSetTimestepsStep,
|
||||||
|
StableDiffusionXLPrepareLatentsStep,
|
||||||
|
StableDiffusionXLPrepareAdditionalConditioningStep,
|
||||||
|
StableDiffusionXLImg2ImgSetTimestepsStep,
|
||||||
|
StableDiffusionXLImg2ImgPrepareLatentsStep,
|
||||||
|
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
|
||||||
|
StableDiffusionXLInpaintPrepareLatentsStep,
|
||||||
|
StableDiffusionXLControlNetInputStep,
|
||||||
|
StableDiffusionXLControlNetUnionInputStep
|
||||||
|
)
|
||||||
|
from .encoders import (
|
||||||
|
StableDiffusionXLTextEncoderStep,
|
||||||
|
StableDiffusionXLAutoIPAdapterStep,
|
||||||
|
StableDiffusionXLAutoVaeEncoderStep,
|
||||||
|
StableDiffusionXLVaeEncoderStep,
|
||||||
|
StableDiffusionXLInpaintVaeEncoderStep,
|
||||||
|
StableDiffusionXLIPAdapterStep
|
||||||
|
)
|
||||||
|
from .decoders import (
|
||||||
|
StableDiffusionXLDecodeStep,
|
||||||
|
StableDiffusionXLInpaintDecodeStep,
|
||||||
|
StableDiffusionXLAutoDecodeStep
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# YiYi notes: comment out for now, work on this later
|
||||||
|
# block mapping
|
||||||
|
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
|
||||||
|
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||||
|
("input", StableDiffusionXLInputStep),
|
||||||
|
("set_timesteps", StableDiffusionXLSetTimestepsStep),
|
||||||
|
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
|
||||||
|
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
|
||||||
|
("denoise", StableDiffusionXLDenoiseLoop),
|
||||||
|
("decode", StableDiffusionXLDecodeStep)
|
||||||
|
])
|
||||||
|
|
||||||
|
IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([
|
||||||
|
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||||
|
("image_encoder", StableDiffusionXLVaeEncoderStep),
|
||||||
|
("input", StableDiffusionXLInputStep),
|
||||||
|
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||||
|
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||||
|
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||||
|
("denoise", StableDiffusionXLDenoiseLoop),
|
||||||
|
("decode", StableDiffusionXLDecodeStep)
|
||||||
|
])
|
||||||
|
|
||||||
|
INPAINT_BLOCKS = InsertableOrderedDict([
|
||||||
|
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||||
|
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
|
||||||
|
("input", StableDiffusionXLInputStep),
|
||||||
|
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||||
|
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
|
||||||
|
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||||
|
("denoise", StableDiffusionXLInpaintDenoiseLoop),
|
||||||
|
("decode", StableDiffusionXLInpaintDecodeStep)
|
||||||
|
])
|
||||||
|
|
||||||
|
CONTROLNET_BLOCKS = InsertableOrderedDict([
|
||||||
|
("controlnet_input", StableDiffusionXLControlNetInputStep),
|
||||||
|
("denoise", StableDiffusionXLControlNetDenoiseStep),
|
||||||
|
])
|
||||||
|
|
||||||
|
CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([
|
||||||
|
("controlnet_input", StableDiffusionXLControlNetUnionInputStep),
|
||||||
|
("denoise", StableDiffusionXLControlNetDenoiseStep),
|
||||||
|
])
|
||||||
|
|
||||||
|
IP_ADAPTER_BLOCKS = InsertableOrderedDict([
|
||||||
|
("ip_adapter", StableDiffusionXLIPAdapterStep),
|
||||||
|
])
|
||||||
|
|
||||||
|
AUTO_BLOCKS = InsertableOrderedDict([
|
||||||
|
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||||
|
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||||
|
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
|
||||||
|
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
|
||||||
|
("denoise", StableDiffusionXLAutoDenoiseStep),
|
||||||
|
("decode", StableDiffusionXLAutoDecodeStep)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
SDXL_SUPPORTED_BLOCKS = {
|
||||||
|
"text2img": TEXT2IMAGE_BLOCKS,
|
||||||
|
"img2img": IMAGE2IMAGE_BLOCKS,
|
||||||
|
"inpaint": INPAINT_BLOCKS,
|
||||||
|
"controlnet": CONTROLNET_BLOCKS,
|
||||||
|
"controlnet_union": CONTROLNET_UNION_BLOCKS,
|
||||||
|
"ip_adapter": IP_ADAPTER_BLOCKS,
|
||||||
|
"auto": AUTO_BLOCKS
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List, Optional, Tuple, Union, Dict
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
|
||||||
|
from ...image_processor import PipelineImageInput
|
||||||
|
from ...pipelines.pipeline_utils import StableDiffusionMixin
|
||||||
|
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
from ..modular_pipeline import ModularLoader
|
||||||
|
from ..modular_pipeline_utils import InputParam, OutputParam
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
|
||||||
|
# YiYi Notes: model specific components:
|
||||||
|
## (1) it should inherit from ModularLoader
|
||||||
|
## (2) acts like a container that holds components and configs
|
||||||
|
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
|
||||||
|
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
|
||||||
|
## (5) how to use together with Components_manager?
|
||||||
|
class StableDiffusionXLModularLoader(
|
||||||
|
ModularLoader,
|
||||||
|
StableDiffusionMixin,
|
||||||
|
TextualInversionLoaderMixin,
|
||||||
|
StableDiffusionXLLoraLoaderMixin,
|
||||||
|
ModularIPAdapterMixin,
|
||||||
|
):
|
||||||
|
@property
|
||||||
|
def default_sample_size(self):
|
||||||
|
default_sample_size = 128
|
||||||
|
if hasattr(self, "unet") and self.unet is not None:
|
||||||
|
default_sample_size = self.unet.config.sample_size
|
||||||
|
return default_sample_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vae_scale_factor(self):
|
||||||
|
vae_scale_factor = 8
|
||||||
|
if hasattr(self, "vae") and self.vae is not None:
|
||||||
|
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||||
|
return vae_scale_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_channels_unet(self):
|
||||||
|
num_channels_unet = 4
|
||||||
|
if hasattr(self, "unet") and self.unet is not None:
|
||||||
|
num_channels_unet = self.unet.config.in_channels
|
||||||
|
return num_channels_unet
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_channels_latents(self):
|
||||||
|
num_channels_latents = 4
|
||||||
|
if hasattr(self, "vae") and self.vae is not None:
|
||||||
|
num_channels_latents = self.vae.config.latent_channels
|
||||||
|
return num_channels_latents
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks
|
||||||
|
SDXL_INPUTS_SCHEMA = {
|
||||||
|
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
|
||||||
|
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
|
||||||
|
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
|
||||||
|
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
|
||||||
|
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
|
||||||
|
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
|
||||||
|
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
|
||||||
|
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
|
||||||
|
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
|
||||||
|
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
|
||||||
|
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
|
||||||
|
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
|
||||||
|
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
|
||||||
|
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
|
||||||
|
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
|
||||||
|
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
|
||||||
|
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
|
||||||
|
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
|
||||||
|
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
|
||||||
|
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
|
||||||
|
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
|
||||||
|
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
|
||||||
|
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
|
||||||
|
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
|
||||||
|
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
|
||||||
|
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
|
||||||
|
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
|
||||||
|
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
|
||||||
|
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
|
||||||
|
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
|
||||||
|
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
|
||||||
|
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
|
||||||
|
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
|
||||||
|
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
|
||||||
|
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
|
||||||
|
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
|
||||||
|
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
|
||||||
|
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||||
|
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
|
||||||
|
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
|
||||||
|
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
|
||||||
|
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
|
||||||
|
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
|
||||||
|
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||||
|
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
|
||||||
|
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
|
||||||
|
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||||
|
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
|
||||||
|
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
|
||||||
|
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
|
||||||
|
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
|
||||||
|
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
|
||||||
|
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
|
||||||
|
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
|
||||||
|
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||||
|
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||||
|
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||||
|
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
|
||||||
|
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
|
||||||
|
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
|
||||||
|
"prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
|
||||||
|
"negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
|
||||||
|
"pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
|
||||||
|
"negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
|
||||||
|
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
|
||||||
|
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||||
|
"image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"),
|
||||||
|
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
|
||||||
|
"masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
|
||||||
|
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||||
|
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
|
||||||
|
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
|
||||||
|
"latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"),
|
||||||
|
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
|
||||||
|
"negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
|
||||||
|
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||||
|
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
|
||||||
|
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||||
|
"ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
|
||||||
|
"negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
|
||||||
|
"images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SDXL_OUTPUTS_SCHEMA = {
|
||||||
|
"images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List, Optional, Tuple, Union, Dict
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import SequentialPipelineBlocks
|
||||||
|
|
||||||
|
from .denoise import StableDiffusionXLAutoDenoiseStep
|
||||||
|
from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep
|
||||||
|
from .decoders import StableDiffusionXLAutoDecodeStep
|
||||||
|
from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
|
||||||
|
block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
|
||||||
|
block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \
|
||||||
|
"- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \
|
||||||
|
"- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \
|
||||||
|
"- to run the controlnet workflow, you need to provide `control_image`\n" + \
|
||||||
|
"- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \
|
||||||
|
"- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
|
||||||
|
"- for text-to-image generation, all you need to provide is `prompt`"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +47,6 @@ else:
|
|||||||
"AutoPipelineForInpainting",
|
"AutoPipelineForInpainting",
|
||||||
"AutoPipelineForText2Image",
|
"AutoPipelineForText2Image",
|
||||||
]
|
]
|
||||||
_import_structure["modular_pipeline"] = ["ModularLoader"]
|
|
||||||
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
||||||
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
||||||
_import_structure["ddim"] = ["DDIMPipeline"]
|
_import_structure["ddim"] = ["DDIMPipeline"]
|
||||||
@@ -330,8 +329,6 @@ else:
|
|||||||
"StableDiffusionXLInpaintPipeline",
|
"StableDiffusionXLInpaintPipeline",
|
||||||
"StableDiffusionXLInstructPix2PixPipeline",
|
"StableDiffusionXLInstructPix2PixPipeline",
|
||||||
"StableDiffusionXLPipeline",
|
"StableDiffusionXLPipeline",
|
||||||
"StableDiffusionXLModularLoader",
|
|
||||||
"StableDiffusionXLAutoPipeline",
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||||
@@ -481,7 +478,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
|
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
|
||||||
from .dit import DiTPipeline
|
from .dit import DiTPipeline
|
||||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||||
from .modular_pipeline import ModularLoader
|
|
||||||
from .pipeline_utils import (
|
from .pipeline_utils import (
|
||||||
AudioPipelineOutput,
|
AudioPipelineOutput,
|
||||||
DiffusionPipeline,
|
DiffusionPipeline,
|
||||||
@@ -706,9 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLImg2ImgPipeline,
|
StableDiffusionXLImg2ImgPipeline,
|
||||||
StableDiffusionXLInpaintPipeline,
|
StableDiffusionXLInpaintPipeline,
|
||||||
StableDiffusionXLInstructPix2PixPipeline,
|
StableDiffusionXLInstructPix2PixPipeline,
|
||||||
StableDiffusionXLModularLoader,
|
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
)
|
)
|
||||||
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
||||||
from .t2i_adapter import (
|
from .t2i_adapter import (
|
||||||
|
|||||||
@@ -29,18 +29,6 @@ else:
|
|||||||
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
|
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
|
||||||
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
|
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
|
||||||
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
|
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
|
||||||
_import_structure["pipeline_stable_diffusion_xl_modular"] = [
|
|
||||||
"StableDiffusionXLControlNetDenoiseStep",
|
|
||||||
"StableDiffusionXLDecodeLatentsStep",
|
|
||||||
"StableDiffusionXLDenoiseStep",
|
|
||||||
"StableDiffusionXLInputStep",
|
|
||||||
"StableDiffusionXLModularLoader",
|
|
||||||
"StableDiffusionXLPrepareAdditionalConditioningStep",
|
|
||||||
"StableDiffusionXLPrepareLatentsStep",
|
|
||||||
"StableDiffusionXLSetTimestepsStep",
|
|
||||||
"StableDiffusionXLTextEncoderStep",
|
|
||||||
"StableDiffusionXLAutoPipeline",
|
|
||||||
]
|
|
||||||
|
|
||||||
if is_transformers_available() and is_flax_available():
|
if is_transformers_available() and is_flax_available():
|
||||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||||
@@ -60,18 +48,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
|
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
|
||||||
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
||||||
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
||||||
from .pipeline_stable_diffusion_xl_modular import (
|
|
||||||
StableDiffusionXLControlNetDenoiseStep,
|
|
||||||
StableDiffusionXLDecodeLatentsStep,
|
|
||||||
StableDiffusionXLDenoiseStep,
|
|
||||||
StableDiffusionXLInputStep,
|
|
||||||
StableDiffusionXLModularLoader,
|
|
||||||
StableDiffusionXLPrepareAdditionalConditioningStep,
|
|
||||||
StableDiffusionXLPrepareLatentsStep,
|
|
||||||
StableDiffusionXLSetTimestepsStep,
|
|
||||||
StableDiffusionXLTextEncoderStep,
|
|
||||||
StableDiffusionXLAutoPipeline,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not (is_transformers_available() and is_flax_available()):
|
if not (is_transformers_available() and is_flax_available()):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -15,13 +15,16 @@
|
|||||||
"""Utilities to dynamically load objects from the Hub."""
|
"""Utilities to dynamically load objects from the Hub."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import signal
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
||||||
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
||||||
|
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
|
||||||
|
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_diffusers_versions():
|
def get_diffusers_versions():
|
||||||
@@ -154,15 +159,87 @@ def check_imports(filename):
|
|||||||
return get_relative_imports(filename)
|
return get_relative_imports(filename)
|
||||||
|
|
||||||
|
|
||||||
def get_class_in_module(class_name, module_path):
|
def _raise_timeout_error(signum, frame):
|
||||||
|
raise ValueError(
|
||||||
|
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
||||||
|
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
||||||
|
if trust_remote_code is None:
|
||||||
|
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||||
|
prev_sig_handler = None
|
||||||
|
try:
|
||||||
|
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||||
|
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||||
|
while trust_remote_code is None:
|
||||||
|
answer = input(
|
||||||
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||||
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||||
|
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||||
|
f"Do you wish to run the custom code? [y/N] "
|
||||||
|
)
|
||||||
|
if answer.lower() in ["yes", "y", "1"]:
|
||||||
|
trust_remote_code = True
|
||||||
|
elif answer.lower() in ["no", "n", "0", ""]:
|
||||||
|
trust_remote_code = False
|
||||||
|
signal.alarm(0)
|
||||||
|
except Exception:
|
||||||
|
# OS which does not support signal.SIGALRM
|
||||||
|
raise ValueError(
|
||||||
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||||
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||||
|
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if prev_sig_handler is not None:
|
||||||
|
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||||
|
signal.alarm(0)
|
||||||
|
elif has_remote_code:
|
||||||
|
# For the CI which puts the timeout at 0
|
||||||
|
_raise_timeout_error(None, None)
|
||||||
|
|
||||||
|
if has_remote_code and not trust_remote_code:
|
||||||
|
raise ValueError(
|
||||||
|
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||||
|
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||||
|
" set the option `trust_remote_code=True` to remove this error."
|
||||||
|
)
|
||||||
|
|
||||||
|
return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_in_module(class_name, module_path, force_reload=False):
|
||||||
"""
|
"""
|
||||||
Import a module on the cache directory for modules and extract a class from it.
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
"""
|
"""
|
||||||
module_path = module_path.replace(os.path.sep, ".")
|
name = os.path.normpath(module_path)
|
||||||
module = importlib.import_module(module_path)
|
if name.endswith(".py"):
|
||||||
|
name = name[:-3]
|
||||||
|
name = name.replace(os.path.sep, ".")
|
||||||
|
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||||
|
|
||||||
|
with _HF_REMOTE_CODE_LOCK:
|
||||||
|
if force_reload:
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||||
|
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||||
|
|
||||||
|
module: ModuleType
|
||||||
|
if cached_module is None:
|
||||||
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
|
# insert it into sys.modules before any loading begins
|
||||||
|
sys.modules[name] = module
|
||||||
|
else:
|
||||||
|
module = cached_module
|
||||||
|
|
||||||
|
module_spec.loader.exec_module(module)
|
||||||
|
|
||||||
if class_name is None:
|
if class_name is None:
|
||||||
return find_pipeline_class(module)
|
return find_pipeline_class(module)
|
||||||
|
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
@@ -454,4 +531,4 @@ def get_class_from_dynamic_module(
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
return get_class_in_module(class_name, final_module)
|
||||||
|
|||||||
Reference in New Issue
Block a user