Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 07798babac | |||
| 9c7e205176 |
@@ -0,0 +1,65 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ..utils import get_logger
|
||||
from ._common import _BATCHED_INPUT_IDENTIFIERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CFG_PARALLEL = "cfg_parallel"
|
||||
|
||||
|
||||
class CFGParallelHook(ModelHook):
    """
    Example hook that shards batched keyword inputs across two distributed processes.

    Each rank runs the wrapped forward on its own slice of every keyword argument listed in
    `_BATCHED_INPUT_IDENTIFIERS` (split along dim 0), then the first element of the output is all-gathered and
    concatenated along dim 0 so every rank ends up with the full result.
    """

    def initialize_hook(self, module):
        """
        Return `module` unchanged after checking that `torch.distributed` has been initialized.

        Raises:
            RuntimeError: If no distributed process group is initialized.
        """
        if not dist.is_initialized():
            raise RuntimeError("Distributed environment not initialized.")
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """
        Run the original forward on this rank's slice of the batched inputs and gather the first output.

        Args:
            module: The module the hook is attached to (not used directly here).
            *args: Positional arguments, forwarded untouched (they are NOT split; a warning is logged).
            **kwargs: Keyword arguments; those named in `_BATCHED_INPUT_IDENTIFIERS` and not `None` are chunked.

        Returns:
            The same kind of container the original forward returned (a tuple, or an instance of the output's
            class), with element 0 replaced by the gathered tensor.
        """
        if args:
            logger.warning(
                "CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
            )

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # NOTE: kept as an assert to preserve the exception type; be aware it is stripped under `python -O`.
        assert world_size == 2, "This is an example hook designed to only work with 2 processes."

        # Keep only this rank's slice of every batched keyword input that was actually provided.
        for key in _BATCHED_INPUT_IDENTIFIERS:
            value = kwargs.get(key)
            if value is None:
                continue
            kwargs[key] = torch.chunk(value, world_size, dim=0)[rank].contiguous()

        output = self.fn_ref.original_forward(*args, **kwargs)

        # Gather the per-rank first outputs and stitch them back together along the batch dimension.
        local_sample = output[0]
        gathered = [torch.empty_like(local_sample) for _ in range(world_size)]
        dist.all_gather(gathered, local_sample)
        sample = torch.cat(gathered, dim=0).contiguous()

        # Mirror whatever container the wrapped forward actually produced. Inferring it from
        # `kwargs.get("return_dict", False)` returned a bare tuple whenever the caller omitted `return_dict`
        # and the wrapped forward returned an output object.
        # NOTE(review): presumably the wrapped models default to `return_dict=True` - confirm.
        if isinstance(output, tuple):
            return (sample, *output[1:])
        return output.__class__(sample, *output[1:])
|
||||
|
||||
|
||||
def apply_cfg_parallel(module: torch.nn.Module) -> None:
    """
    Register a fresh `CFGParallelHook` on `module` under the `_CFG_PARALLEL` name.

    The module's hook registry is fetched (or created) through
    `HookRegistry.check_if_exists_or_initialize` before the hook is registered.
    """
    HookRegistry.check_if_exists_or_initialize(module).register_hook(CFGParallelHook(), _CFG_PARALLEL)
|
||||
@@ -0,0 +1,26 @@
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
# Attribute names under which transformer blocks are looked up, grouped by kind.
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

# Union of the three groups above without duplicates. `dict.fromkeys` keeps first-seen order, so the tuple is
# identical on every run (going through a `set` made the order depend on string hash randomization).
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    dict.fromkeys(
        (
            *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
            *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
            *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
        )
    )
)

# Keyword-argument names treated as batched inputs.
_BATCHED_INPUT_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "pooled_projections",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
    "guidance",
)
|
||||
@@ -20,19 +20,18 @@ import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..utils import logging
|
||||
from ._common import (
|
||||
_ATTENTION_CLASSES,
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
)
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyramidAttentionBroadcastConfig:
|
||||
r"""
|
||||
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
|
||||
@@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
||||
motion_adapter: MotionAdapter,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
|
||||
@@ -246,7 +246,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
||||
motion_adapter: MotionAdapter,
|
||||
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
||||
scheduler: Union[
|
||||
|
||||
@@ -232,8 +232,8 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
Tuple[HunyuanDiT2DControlNetModel],
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
],
|
||||
text_encoder_2=T5EncoderModel,
|
||||
tokenizer_2=MT5Tokenizer,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
PreTrainedModel,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
|
||||
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
||||
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
||||
additional conditioning.
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
@@ -202,8 +202,8 @@ class StableDiffusion3ControlNetPipeline(
|
||||
controlnet: Union[
|
||||
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
|
||||
],
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
image_encoder: Optional[SiglipVisionModel] = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
|
||||
+4
-4
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
PreTrainedModel,
|
||||
SiglipImageProcessor,
|
||||
SiglipModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -223,8 +223,8 @@ class StableDiffusion3ControlNetInpaintingPipeline(
|
||||
controlnet: Union[
|
||||
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
|
||||
],
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
image_encoder: SiglipModel = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet1DModel
|
||||
from ...schedulers import SchedulerMixin
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import is_torch_xla_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
|
||||
super().__init__()
|
||||
|
||||
# make sure scheduler can always be converted to DDIM
|
||||
|
||||
@@ -17,6 +17,8 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import is_torch_xla_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
scheduler: RePaintScheduler
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -207,8 +207,8 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2=T5EncoderModel,
|
||||
tokenizer_2=MT5Tokenizer,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import urllib.parse as ul
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
@@ -144,13 +144,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`AutoModel`]):
|
||||
Frozen text-encoder. Lumina-T2I uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`AutoModel`):
|
||||
Tokenizer of class
|
||||
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
|
||||
text_encoder ([`GemmaPreTrainedModel`]):
|
||||
Frozen Gemma text-encoder.
|
||||
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
|
||||
Gemma tokenizer.
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
@@ -185,8 +182,8 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
|
||||
transformer: LuminaNextDiT2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: GemmaPreTrainedModel,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import Lumina2LoraLoaderMixin
|
||||
@@ -143,13 +143,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`AutoModel`]):
|
||||
Frozen text-encoder. Lumina-T2I uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`AutoModel`):
|
||||
Tokenizer of class
|
||||
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
|
||||
text_encoder ([`Gemma2PreTrainedModel`]):
|
||||
Frozen Gemma2 text-encoder.
|
||||
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
|
||||
Gemma tokenizer.
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
@@ -165,8 +162,8 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
transformer: Lumina2Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import warnings
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: AutoModelForCausalLM,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC,
|
||||
transformer: SanaTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
|
||||
break
|
||||
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
|
||||
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
|
||||
|
||||
|
||||
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
|
||||
"""
|
||||
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
|
||||
the correct type as well.
|
||||
"""
|
||||
if not isinstance(class_or_tuple, tuple):
|
||||
class_or_tuple = (class_or_tuple,)
|
||||
|
||||
# Unpack unions
|
||||
unpacked_class_or_tuple = []
|
||||
for t in class_or_tuple:
|
||||
if get_origin(t) is Union:
|
||||
unpacked_class_or_tuple.extend(get_args(t))
|
||||
else:
|
||||
unpacked_class_or_tuple.append(t)
|
||||
class_or_tuple = tuple(unpacked_class_or_tuple)
|
||||
|
||||
if Any in class_or_tuple:
|
||||
return True
|
||||
|
||||
obj_type = type(obj)
|
||||
# Classes with obj's type
|
||||
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
|
||||
|
||||
# Singular types (e.g. int, ControlNet, ...)
|
||||
# Untyped collections (e.g. List, but not List[int])
|
||||
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
|
||||
if () in elem_class_or_tuple:
|
||||
return True
|
||||
# Typed lists or sets
|
||||
elif obj_type in (list, set):
|
||||
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
|
||||
# Typed tuples
|
||||
elif obj_type is tuple:
|
||||
return any(
|
||||
# Tuples with any length and single type (e.g. Tuple[int, ...])
|
||||
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
|
||||
or
|
||||
# Tuples with fixed length and any types (e.g. Tuple[int, str])
|
||||
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
|
||||
for t in elem_class_or_tuple
|
||||
)
|
||||
# Typed dicts
|
||||
elif obj_type is dict:
|
||||
return any(
|
||||
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
|
||||
for kt, vt in elem_class_or_tuple
|
||||
)
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _get_detailed_type(obj: Any) -> Type:
|
||||
"""
|
||||
Gets a detailed type for an object, including nested types for collections.
|
||||
"""
|
||||
obj_type = type(obj)
|
||||
|
||||
if obj_type in (list, set):
|
||||
obj_origin_type = List if obj_type is list else Set
|
||||
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
|
||||
return obj_origin_type[elems_type]
|
||||
elif obj_type is tuple:
|
||||
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
|
||||
elif obj_type is dict:
|
||||
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
|
||||
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
|
||||
return Dict[keys_type, values_type]
|
||||
else:
|
||||
return obj_type
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -79,10 +78,12 @@ from .pipeline_loading_utils import (
|
||||
_fetch_class_library_tuple,
|
||||
_get_custom_components_and_folders,
|
||||
_get_custom_pipeline_class,
|
||||
_get_detailed_type,
|
||||
_get_final_device_map,
|
||||
_get_ignore_patterns,
|
||||
_get_pipeline_class,
|
||||
_identify_model_variants,
|
||||
_is_valid_type,
|
||||
_maybe_raise_error_for_incorrect_transformers,
|
||||
_maybe_raise_warning_for_inpainting,
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
@@ -876,26 +877,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for key in init_dict.keys():
|
||||
if key not in passed_class_obj:
|
||||
continue
|
||||
if "scheduler" in key:
|
||||
continue
|
||||
|
||||
class_obj = passed_class_obj[key]
|
||||
_expected_class_types = []
|
||||
for expected_type in expected_types[key]:
|
||||
if isinstance(expected_type, enum.EnumMeta):
|
||||
_expected_class_types.extend(expected_type.__members__.keys())
|
||||
else:
|
||||
_expected_class_types.append(expected_type.__name__)
|
||||
|
||||
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
|
||||
if not _is_valid_type:
|
||||
logger.warning(
|
||||
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
|
||||
)
|
||||
|
||||
# Special case: safety_checker must be loaded separately when using `from_flax`
|
||||
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
||||
raise NotImplementedError(
|
||||
@@ -1015,10 +996,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# 10. Instantiate the pipeline
|
||||
# 10. Type checking init arguments
|
||||
for kw, arg in init_kwargs.items():
|
||||
# Too complex to validate with type annotation alone
|
||||
if "scheduler" in kw:
|
||||
continue
|
||||
# Many tokenizer annotations don't include its "Fast" variant, so skip this
|
||||
# e.g T5Tokenizer but not T5TokenizerFast
|
||||
elif "tokenizer" in kw:
|
||||
continue
|
||||
elif (
|
||||
arg is not None # Skip if None
|
||||
and not expected_types[kw] == (inspect.Signature.empty,) # Skip if no type annotations
|
||||
and not _is_valid_type(arg, expected_types[kw]) # Check type
|
||||
):
|
||||
logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
|
||||
|
||||
# 11. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
|
||||
# 11. Save where the model was instantiated from
|
||||
# 12. Save where the model was instantiated from
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
if device_map is not None:
|
||||
setattr(model, "hf_device_map", final_device_map)
|
||||
|
||||
@@ -20,7 +20,7 @@ import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
@@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: AutoModelForCausalLM,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC,
|
||||
transformer: SanaTransformer2DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
@@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The CLIP tokenizer.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
text_encoder (`CLIPTextModelWithProjection`):
|
||||
The CLIP text encoder.
|
||||
decoder ([`StableCascadeUNet`]):
|
||||
The Stable Cascade decoder unet.
|
||||
@@ -93,7 +93,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
self,
|
||||
decoder: StableCascadeUNet,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
latent_dim_scale: float = 10.67,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
@@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The decoder tokenizer to be used for text inputs.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
text_encoder (`CLIPTextModelWithProjection`):
|
||||
The decoder text encoder to be used for text inputs.
|
||||
decoder (`StableCascadeUNet`):
|
||||
The decoder model to be used for decoder image generation pipeline.
|
||||
@@ -60,14 +60,18 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
The scheduler to be used for decoder image generation pipeline.
|
||||
vqgan (`PaellaVQModel`):
|
||||
The VQGAN model to be used for decoder image generation pipeline.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
prior_prior (`StableCascadeUNet`):
|
||||
The prior model to be used for prior pipeline.
|
||||
prior_text_encoder (`CLIPTextModelWithProjection`):
|
||||
The prior text encoder to be used for text inputs.
|
||||
prior_tokenizer (`CLIPTokenizer`):
|
||||
The prior tokenizer to be used for text inputs.
|
||||
prior_scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for prior pipeline.
|
||||
prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
"""
|
||||
|
||||
_load_connected_pipes = True
|
||||
@@ -76,12 +80,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
decoder: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
prior_prior: StableCascadeUNet,
|
||||
prior_text_encoder: CLIPTextModel,
|
||||
prior_text_encoder: CLIPTextModelWithProjection,
|
||||
prior_tokenizer: CLIPTokenizer,
|
||||
prior_scheduler: DDPMWuerstchenScheduler,
|
||||
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
|
||||
@@ -141,7 +141,7 @@ class StableUnCLIPPipeline(
|
||||
image_noising_scheduler: KarrasDiffusionSchedulers,
|
||||
# regular denoising components
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
text_encoder: CLIPTextModel,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
# vae
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
PreTrainedModel,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
@@ -197,8 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
image_encoder: SiglipVisionModel = None,
|
||||
feature_extractor: SiglipImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
PreTrainedModel,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
|
||||
@@ -214,8 +218,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
image_encoder: Optional[SiglipVisionModel] = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
PreTrainedModel,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
@@ -217,8 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
image_encoder: Optional[SiglipVisionModel] = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
+27
-11
@@ -19,15 +19,31 @@ from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPTokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import (
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import LMSDiscreteScheduler
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
+2
-2
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
||||
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
|
||||
|
||||
|
||||
class CustomLocalPipeline(DiffusionPipeline):
|
||||
@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
|
||||
[`DDPMScheduler`], or [`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
+2
-1
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import SchedulerMixin, UNet2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
@@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline):
|
||||
[`DDPMScheduler`], or [`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -91,10 +91,10 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
|
||||
text_encoder = Gemma2Model(config)
|
||||
|
||||
components = {
|
||||
"transformer": transformer.eval(),
|
||||
"transformer": transformer,
|
||||
"vae": vae.eval(),
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder.eval(),
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
Reference in New Issue
Block a user