Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 98771d3611 |
@@ -1,65 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ..utils import get_logger
|
||||
from ._common import _BATCHED_INPUT_IDENTIFIERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CFG_PARALLEL = "cfg_parallel"
|
||||
|
||||
|
||||
class CFGParallelHook(ModelHook):
|
||||
def initialize_hook(self, module):
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError("Distributed environment not initialized.")
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
logger.warning(
|
||||
"CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
|
||||
)
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
assert world_size == 2, "This is an example hook designed to only work with 2 processes."
|
||||
|
||||
for key in list(kwargs.keys()):
|
||||
if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
|
||||
continue
|
||||
kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
sample = output[0]
|
||||
sample_list = [torch.empty_like(sample) for _ in range(world_size)]
|
||||
dist.all_gather(sample_list, sample)
|
||||
sample = torch.cat(sample_list, dim=0).contiguous()
|
||||
|
||||
return_dict = kwargs.get("return_dict", False)
|
||||
if not return_dict:
|
||||
return (sample, *output[1:])
|
||||
return output.__class__(sample, *output[1:])
|
||||
|
||||
|
||||
def apply_cfg_parallel(module: torch.nn.Module) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = CFGParallelHook()
|
||||
registry.register_hook(hook, _CFG_PARALLEL)
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
@@ -14,13 +28,3 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
|
||||
_BATCHED_INPUT_IDENTIFIERS = (
|
||||
"hidden_states",
|
||||
"encoder_hidden_states",
|
||||
"pooled_projections",
|
||||
"timestep",
|
||||
"attention_mask",
|
||||
"encoder_attention_mask",
|
||||
"guidance",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .utils import _extract_return_information
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
||||
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirstBlockCacheConfig:
|
||||
r"""
|
||||
Configuration for [First Block
|
||||
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.05`):
|
||||
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
||||
higher threshold usually results in lower number of forward passes and faster inference, but might lead to
|
||||
poorer generation quality. A lower threshold may not result in significant generation speedup. The
|
||||
threshold is compared against the absmean difference of the residuals between the current and cached
|
||||
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
|
||||
skipped.
|
||||
"""
|
||||
|
||||
threshold: float = 0.05
|
||||
|
||||
|
||||
class FBCSharedBlockState:
|
||||
def __init__(self) -> None:
|
||||
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.head_block_residual: torch.Tensor = None
|
||||
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
self.tail_block_residuals = None
|
||||
self.should_compute = True
|
||||
|
||||
def __repr__(self):
|
||||
return f"FirstBlockCacheSharedState(cache={self.cache})"
|
||||
|
||||
|
||||
class FBCHeadBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
|
||||
self.shared_state = shared_state
|
||||
self.threshold = threshold
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
inputs = inspect.signature(module.__class__.forward)
|
||||
inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
|
||||
inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}
|
||||
|
||||
try:
|
||||
outputs = _extract_return_information(module.__class__.forward)
|
||||
outputs_index_to_str = dict(enumerate(outputs))
|
||||
outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
|
||||
except RuntimeError:
|
||||
logger.error(f"Failed to extract return information for {module.__class__}")
|
||||
raise NotImplementedError(
|
||||
f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
|
||||
f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
|
||||
f"in order for us to add support for this module."
|
||||
)
|
||||
|
||||
self._inputs_index_to_str = inputs_index_to_str
|
||||
self._inputs_str_to_index = inputs_str_to_index
|
||||
self._outputs_index_to_str = outputs_index_to_str
|
||||
self._outputs_str_to_index = outputs_str_to_index
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
hs_input_idx = self._inputs_str_to_index.get("hidden_states")
|
||||
ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
|
||||
original_hs = kwargs.get("hidden_states", None)
|
||||
original_ehs = kwargs.get("encoder_hidden_states", None)
|
||||
original_hs = original_hs if original_hs is not None else args[hs_input_idx]
|
||||
if ehs_input_idx is not None:
|
||||
original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]
|
||||
|
||||
hs_output_idx = self._outputs_str_to_index.get("hidden_states")
|
||||
ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
|
||||
assert (ehs_input_idx is None) == (ehs_output_idx is None)
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
hs_residual = None
|
||||
if isinstance(output, tuple):
|
||||
hs_residual = output[hs_output_idx] - original_hs
|
||||
else:
|
||||
hs_residual = output - original_hs
|
||||
|
||||
should_compute = self._should_compute_remaining_blocks(hs_residual)
|
||||
self.shared_state.should_compute = should_compute
|
||||
|
||||
hs, ehs = None, None
|
||||
if not should_compute:
|
||||
# Apply caching
|
||||
logger.info("Skipping forward pass through remaining blocks")
|
||||
hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx]
|
||||
if ehs_output_idx is not None:
|
||||
ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx]
|
||||
|
||||
if isinstance(output, tuple):
|
||||
return_output = [None] * len(output)
|
||||
return_output[hs_output_idx] = hs
|
||||
return_output[ehs_output_idx] = ehs
|
||||
return_output = tuple(return_output)
|
||||
else:
|
||||
return_output = hs
|
||||
return return_output
|
||||
else:
|
||||
logger.info("Computing forward pass through remaining blocks")
|
||||
if isinstance(output, tuple):
|
||||
head_block_output = [None] * len(output)
|
||||
head_block_output[0] = output[hs_output_idx]
|
||||
head_block_output[1] = output[ehs_output_idx]
|
||||
else:
|
||||
head_block_output = output
|
||||
self.shared_state.head_block_output = head_block_output
|
||||
self.shared_state.head_block_residual = hs_residual
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.shared_state.reset()
|
||||
return module
|
||||
|
||||
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
|
||||
if self.shared_state.head_block_residual is None:
|
||||
return True
|
||||
prev_hs_residual = self.shared_state.head_block_residual
|
||||
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
|
||||
prev_hs_mean = prev_hs_residual.abs().mean()
|
||||
diff = (hs_absmean / prev_hs_mean).item()
|
||||
logger.info(f"Diff: {diff}, Threshold: {self.threshold}")
|
||||
return diff > self.threshold
|
||||
|
||||
|
||||
class FBCBlockHook(ModelHook):
|
||||
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
|
||||
super().__init__()
|
||||
self.shared_state = shared_state
|
||||
self.is_tail = is_tail
|
||||
|
||||
def initialize_hook(self, module):
|
||||
inputs = inspect.signature(module.__class__.forward)
|
||||
inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
|
||||
inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}
|
||||
|
||||
try:
|
||||
outputs = _extract_return_information(module.__class__.forward)
|
||||
outputs_index_to_str = dict(enumerate(outputs))
|
||||
outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
|
||||
except RuntimeError:
|
||||
logger.error(f"Failed to extract return information for {module.__class__}")
|
||||
raise NotImplementedError(
|
||||
f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
|
||||
f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
|
||||
f"in order for us to add support for this module."
|
||||
)
|
||||
|
||||
self._inputs_index_to_str = inputs_index_to_str
|
||||
self._inputs_str_to_index = inputs_str_to_index
|
||||
self._outputs_index_to_str = outputs_index_to_str
|
||||
self._outputs_str_to_index = outputs_str_to_index
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
hs_input_idx = self._inputs_str_to_index.get("hidden_states")
|
||||
ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
|
||||
original_hs = kwargs.get("hidden_states", None)
|
||||
original_ehs = kwargs.get("encoder_hidden_states", None)
|
||||
original_hs = original_hs if original_hs is not None else args[hs_input_idx]
|
||||
if ehs_input_idx is not None:
|
||||
original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]
|
||||
|
||||
hs_output_idx = self._outputs_str_to_index.get("hidden_states")
|
||||
ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
|
||||
assert (ehs_input_idx is None) == (ehs_output_idx is None)
|
||||
|
||||
if self.shared_state.should_compute:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if self.is_tail:
|
||||
hs_residual, ehs_residual = None, None
|
||||
if isinstance(output, tuple):
|
||||
hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0]
|
||||
ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1]
|
||||
else:
|
||||
hs_residual = output - self.shared_state.head_block_output
|
||||
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
|
||||
return output
|
||||
|
||||
output_count = len(self._outputs_index_to_str.keys())
|
||||
return_output = [None] * output_count if output_count > 1 else original_hs
|
||||
if output_count == 1:
|
||||
return_output = original_hs
|
||||
else:
|
||||
return_output[hs_output_idx] = original_hs
|
||||
return_output[ehs_output_idx] = original_ehs
|
||||
return return_output
|
||||
|
||||
|
||||
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
||||
shared_state = FBCSharedBlockState()
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for block in submodule:
|
||||
remaining_blocks.append((name, block))
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
|
||||
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
logger.debug(f"Apply FBCBlockHook to '{name}'")
|
||||
apply_fbc_block_hook(block, shared_state)
|
||||
|
||||
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
|
||||
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
|
||||
|
||||
|
||||
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCHeadBlockHook(state, threshold)
|
||||
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCBlockHook(state, is_tail)
|
||||
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
||||
@@ -0,0 +1,59 @@
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import List
|
||||
|
||||
|
||||
def _extract_return_information(func) -> List[str]:
|
||||
"""Extracts return variable names in order from a function."""
|
||||
try:
|
||||
source = inspect.getsource(func)
|
||||
source = textwrap.dedent(source) # Modify indentation to make parsing compatible
|
||||
except (OSError, TypeError):
|
||||
try:
|
||||
source_file = inspect.getfile(func)
|
||||
with open(source_file, "r", encoding="utf-8") as f:
|
||||
source = f.read()
|
||||
|
||||
# Extract function definition manually
|
||||
source_lines = source.splitlines()
|
||||
func_name = func.__name__
|
||||
start_line = None
|
||||
indent_level = None
|
||||
extracted_lines = []
|
||||
|
||||
for i, line in enumerate(source_lines):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith(f"def {func_name}("):
|
||||
start_line = i
|
||||
indent_level = len(line) - len(line.lstrip())
|
||||
extracted_lines.append(line)
|
||||
continue
|
||||
|
||||
if start_line is not None:
|
||||
# Stop when indentation level decreases (end of function)
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
if current_indent <= indent_level and line.strip():
|
||||
break
|
||||
extracted_lines.append(line)
|
||||
|
||||
source = "\n".join(extracted_lines)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to retrieve function source: {e}")
|
||||
|
||||
# Parse source code using AST
|
||||
tree = ast.parse(source)
|
||||
return_vars = []
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
if isinstance(node.value, ast.Tuple):
|
||||
# Multiple return values
|
||||
return_vars.extend(var.id for var in node.value.elts if isinstance(var, ast.Name))
|
||||
elif isinstance(node.value, ast.Name):
|
||||
# Single return value
|
||||
return_vars.append(node.value.id)
|
||||
|
||||
visitor = ReturnVisitor()
|
||||
visitor.visit(tree)
|
||||
return return_vars
|
||||
@@ -87,10 +87,13 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
@@ -108,7 +111,10 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[encoder_hidden_states.size(1), hidden_states.size(1) - encoder_hidden_states.size(1)], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -224,7 +230,7 @@ class FluxTransformerBlock(nn.Module):
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FluxTransformer2DModel(
|
||||
@@ -517,7 +523,7 @@ class FluxTransformer2DModel(
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
@@ -526,7 +532,7 @@ class FluxTransformer2DModel(
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
@@ -545,20 +551,21 @@ class FluxTransformer2DModel(
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
@@ -568,12 +575,7 @@ class FluxTransformer2DModel(
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
@@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
||||
unet: UNet2DConditionModel,
|
||||
motion_adapter: MotionAdapter,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
|
||||
@@ -246,7 +246,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: Union[UNet2DConditionModel, UNetMotionModel],
|
||||
unet: UNet2DConditionModel,
|
||||
motion_adapter: MotionAdapter,
|
||||
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
||||
scheduler: Union[
|
||||
|
||||
@@ -232,8 +232,8 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
Tuple[HunyuanDiT2DControlNetModel],
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
],
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
text_encoder_2=T5EncoderModel,
|
||||
tokenizer_2=MT5Tokenizer,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
PreTrainedModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
|
||||
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
||||
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
||||
additional conditioning.
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
@@ -202,8 +202,8 @@ class StableDiffusion3ControlNetPipeline(
|
||||
controlnet: Union[
|
||||
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
|
||||
],
|
||||
image_encoder: Optional[SiglipVisionModel] = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
|
||||
+4
-4
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
SiglipImageProcessor,
|
||||
SiglipModel,
|
||||
PreTrainedModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -223,8 +223,8 @@ class StableDiffusion3ControlNetInpaintingPipeline(
|
||||
controlnet: Union[
|
||||
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
|
||||
],
|
||||
image_encoder: SiglipModel = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,8 +17,6 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet1DModel
|
||||
from ...schedulers import SchedulerMixin
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
@@ -51,7 +49,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import is_torch_xla_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -48,7 +47,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
# make sure scheduler can always be converted to DDIM
|
||||
|
||||
@@ -17,8 +17,6 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import is_torch_xla_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
@@ -49,7 +47,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
scheduler: RePaintScheduler
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -207,8 +207,8 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2: Optional[T5EncoderModel] = None,
|
||||
tokenizer_2: Optional[MT5Tokenizer] = None,
|
||||
text_encoder_2=T5EncoderModel,
|
||||
tokenizer_2=MT5Tokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import urllib.parse as ul
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
@@ -144,10 +144,13 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`GemmaPreTrainedModel`]):
|
||||
Frozen Gemma text-encoder.
|
||||
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
|
||||
Gemma tokenizer.
|
||||
text_encoder ([`AutoModel`]):
|
||||
Frozen text-encoder. Lumina-T2I uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`AutoModel`):
|
||||
Tokenizer of class
|
||||
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
@@ -182,8 +185,8 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
|
||||
transformer: LuminaNextDiT2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: GemmaPreTrainedModel,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import Lumina2LoraLoaderMixin
|
||||
@@ -143,10 +143,13 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Gemma2PreTrainedModel`]):
|
||||
Frozen Gemma2 text-encoder.
|
||||
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
|
||||
Gemma tokenizer.
|
||||
text_encoder ([`AutoModel`]):
|
||||
Frozen text-encoder. Lumina-T2I uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`AutoModel`):
|
||||
Tokenizer of class
|
||||
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
@@ -162,8 +165,8 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
transformer: Lumina2Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import warnings
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: AutoModelForCausalLM,
|
||||
vae: AutoencoderDC,
|
||||
transformer: SanaTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
|
||||
break
|
||||
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
|
||||
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
|
||||
|
||||
|
||||
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
|
||||
"""
|
||||
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
|
||||
the correct type as well.
|
||||
"""
|
||||
if not isinstance(class_or_tuple, tuple):
|
||||
class_or_tuple = (class_or_tuple,)
|
||||
|
||||
# Unpack unions
|
||||
unpacked_class_or_tuple = []
|
||||
for t in class_or_tuple:
|
||||
if get_origin(t) is Union:
|
||||
unpacked_class_or_tuple.extend(get_args(t))
|
||||
else:
|
||||
unpacked_class_or_tuple.append(t)
|
||||
class_or_tuple = tuple(unpacked_class_or_tuple)
|
||||
|
||||
if Any in class_or_tuple:
|
||||
return True
|
||||
|
||||
obj_type = type(obj)
|
||||
# Classes with obj's type
|
||||
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
|
||||
|
||||
# Singular types (e.g. int, ControlNet, ...)
|
||||
# Untyped collections (e.g. List, but not List[int])
|
||||
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
|
||||
if () in elem_class_or_tuple:
|
||||
return True
|
||||
# Typed lists or sets
|
||||
elif obj_type in (list, set):
|
||||
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
|
||||
# Typed tuples
|
||||
elif obj_type is tuple:
|
||||
return any(
|
||||
# Tuples with any length and single type (e.g. Tuple[int, ...])
|
||||
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
|
||||
or
|
||||
# Tuples with fixed length and any types (e.g. Tuple[int, str])
|
||||
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
|
||||
for t in elem_class_or_tuple
|
||||
)
|
||||
# Typed dicts
|
||||
elif obj_type is dict:
|
||||
return any(
|
||||
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
|
||||
for kt, vt in elem_class_or_tuple
|
||||
)
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _get_detailed_type(obj: Any) -> Type:
|
||||
"""
|
||||
Gets a detailed type for an object, including nested types for collections.
|
||||
"""
|
||||
obj_type = type(obj)
|
||||
|
||||
if obj_type in (list, set):
|
||||
obj_origin_type = List if obj_type is list else Set
|
||||
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
|
||||
return obj_origin_type[elems_type]
|
||||
elif obj_type is tuple:
|
||||
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
|
||||
elif obj_type is dict:
|
||||
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
|
||||
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
|
||||
return Dict[keys_type, values_type]
|
||||
else:
|
||||
return obj_type
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -78,12 +79,10 @@ from .pipeline_loading_utils import (
|
||||
_fetch_class_library_tuple,
|
||||
_get_custom_components_and_folders,
|
||||
_get_custom_pipeline_class,
|
||||
_get_detailed_type,
|
||||
_get_final_device_map,
|
||||
_get_ignore_patterns,
|
||||
_get_pipeline_class,
|
||||
_identify_model_variants,
|
||||
_is_valid_type,
|
||||
_maybe_raise_error_for_incorrect_transformers,
|
||||
_maybe_raise_warning_for_inpainting,
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
@@ -877,6 +876,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for key in init_dict.keys():
|
||||
if key not in passed_class_obj:
|
||||
continue
|
||||
if "scheduler" in key:
|
||||
continue
|
||||
|
||||
class_obj = passed_class_obj[key]
|
||||
_expected_class_types = []
|
||||
for expected_type in expected_types[key]:
|
||||
if isinstance(expected_type, enum.EnumMeta):
|
||||
_expected_class_types.extend(expected_type.__members__.keys())
|
||||
else:
|
||||
_expected_class_types.append(expected_type.__name__)
|
||||
|
||||
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
|
||||
if not _is_valid_type:
|
||||
logger.warning(
|
||||
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
|
||||
)
|
||||
|
||||
# Special case: safety_checker must be loaded separately when using `from_flax`
|
||||
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
||||
raise NotImplementedError(
|
||||
@@ -996,26 +1015,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# 10. Type checking init arguments
|
||||
for kw, arg in init_kwargs.items():
|
||||
# Too complex to validate with type annotation alone
|
||||
if "scheduler" in kw:
|
||||
continue
|
||||
# Many tokenizer annotations don't include its "Fast" variant, so skip this
|
||||
# e.g T5Tokenizer but not T5TokenizerFast
|
||||
elif "tokenizer" in kw:
|
||||
continue
|
||||
elif (
|
||||
arg is not None # Skip if None
|
||||
and not expected_types[kw] == (inspect.Signature.empty,) # Skip if no type annotations
|
||||
and not _is_valid_type(arg, expected_types[kw]) # Check type
|
||||
):
|
||||
logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
|
||||
|
||||
# 11. Instantiate the pipeline
|
||||
# 10. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
|
||||
# 12. Save where the model was instantiated from
|
||||
# 11. Save where the model was instantiated from
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
if device_map is not None:
|
||||
setattr(model, "hf_device_map", final_device_map)
|
||||
|
||||
@@ -20,7 +20,7 @@ import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
@@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: AutoModelForCausalLM,
|
||||
vae: AutoencoderDC,
|
||||
transformer: SanaTransformer2DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
@@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The CLIP tokenizer.
|
||||
text_encoder (`CLIPTextModelWithProjection`):
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The CLIP text encoder.
|
||||
decoder ([`StableCascadeUNet`]):
|
||||
The Stable Cascade decoder unet.
|
||||
@@ -93,7 +93,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
self,
|
||||
decoder: StableCascadeUNet,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
text_encoder: CLIPTextModel,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
latent_dim_scale: float = 10.67,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
@@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The decoder tokenizer to be used for text inputs.
|
||||
text_encoder (`CLIPTextModelWithProjection`):
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The decoder text encoder to be used for text inputs.
|
||||
decoder (`StableCascadeUNet`):
|
||||
The decoder model to be used for decoder image generation pipeline.
|
||||
@@ -60,18 +60,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
The scheduler to be used for decoder image generation pipeline.
|
||||
vqgan (`PaellaVQModel`):
|
||||
The VQGAN model to be used for decoder image generation pipeline.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
prior_prior (`StableCascadeUNet`):
|
||||
The prior model to be used for prior pipeline.
|
||||
prior_text_encoder (`CLIPTextModelWithProjection`):
|
||||
The prior text encoder to be used for text inputs.
|
||||
prior_tokenizer (`CLIPTokenizer`):
|
||||
The prior tokenizer to be used for text inputs.
|
||||
prior_scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for prior pipeline.
|
||||
prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
"""
|
||||
|
||||
_load_connected_pipes = True
|
||||
@@ -80,12 +76,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
text_encoder: CLIPTextModel,
|
||||
decoder: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
prior_prior: StableCascadeUNet,
|
||||
prior_text_encoder: CLIPTextModelWithProjection,
|
||||
prior_text_encoder: CLIPTextModel,
|
||||
prior_tokenizer: CLIPTokenizer,
|
||||
prior_scheduler: DDPMWuerstchenScheduler,
|
||||
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
|
||||
@@ -141,7 +141,7 @@ class StableUnCLIPPipeline(
|
||||
image_noising_scheduler: KarrasDiffusionSchedulers,
|
||||
# regular denoising components
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
# vae
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
PreTrainedModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
@@ -197,8 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: SiglipVisionModel = None,
|
||||
feature_extractor: SiglipImageProcessor = None,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
PreTrainedModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -197,10 +197,6 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
|
||||
@@ -218,8 +214,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: Optional[SiglipVisionModel] = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
PreTrainedModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`SiglipVisionModel`, *optional*):
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`SiglipImageProcessor`, *optional*):
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
@@ -217,8 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: Optional[SiglipVisionModel] = None,
|
||||
feature_extractor: Optional[SiglipImageProcessor] = None,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
+11
-27
@@ -19,31 +19,15 @@ from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPTokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import (
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...schedulers import LMSDiscreteScheduler
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -111,13 +95,13 @@ class StableDiffusionKDiffusionPipeline(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
+2
-2
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
|
||||
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
class CustomLocalPipeline(DiffusionPipeline):
|
||||
@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
|
||||
[`DDPMScheduler`], or [`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
+1
-2
@@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import SchedulerMixin, UNet2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
@@ -34,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
|
||||
[`DDPMScheduler`], or [`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -91,10 +91,10 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
|
||||
text_encoder = Gemma2Model(config)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"transformer": transformer.eval(),
|
||||
"vae": vae.eval(),
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder": text_encoder.eval(),
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
Reference in New Issue
Block a user