Compare commits

..

1 Commits

Author SHA1 Message Date
Aryan 98771d3611 update 2025-02-23 13:21:01 +01:00
31 changed files with 464 additions and 296 deletions
-65
View File
@@ -1,65 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
from ..utils import get_logger
from ._common import _BATCHED_INPUT_IDENTIFIERS
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_CFG_PARALLEL = "cfg_parallel"
class CFGParallelHook(ModelHook):
def initialize_hook(self, module):
if not dist.is_initialized():
raise RuntimeError("Distributed environment not initialized.")
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if len(args) > 0:
logger.warning(
"CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
)
world_size = dist.get_world_size()
rank = dist.get_rank()
assert world_size == 2, "This is an example hook designed to only work with 2 processes."
for key in list(kwargs.keys()):
if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
continue
kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()
output = self.fn_ref.original_forward(*args, **kwargs)
sample = output[0]
sample_list = [torch.empty_like(sample) for _ in range(world_size)]
dist.all_gather(sample_list, sample)
sample = torch.cat(sample_list, dim=0).contiguous()
return_dict = kwargs.get("return_dict", False)
if not return_dict:
return (sample, *output[1:])
return output.__class__(sample, *output[1:])
def apply_cfg_parallel(module: torch.nn.Module) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = CFGParallelHook()
registry.register_hook(hook, _CFG_PARALLEL)
+14 -10
View File
@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention_processor import Attention, MochiAttention
@@ -14,13 +28,3 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
_BATCHED_INPUT_IDENTIFIERS = (
"hidden_states",
"encoder_hidden_states",
"pooled_projections",
"timestep",
"attention_mask",
"encoder_attention_mask",
"guidance",
)
+262
View File
@@ -0,0 +1,262 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from .hooks import HookRegistry, ModelHook
from .utils import _extract_return_information
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in lower number of forward passes and faster inference, but might lead to
poorer generation quality. A lower threshold may not result in significant generation speedup. The
threshold is compared against the absmean difference of the residuals between the current and cached
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState:
def __init__(self) -> None:
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
def __repr__(self):
return f"FirstBlockCacheSharedState(cache={self.cache})"
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
self.shared_state = shared_state
self.threshold = threshold
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
inputs = inspect.signature(module.__class__.forward)
inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}
try:
outputs = _extract_return_information(module.__class__.forward)
outputs_index_to_str = dict(enumerate(outputs))
outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
except RuntimeError:
logger.error(f"Failed to extract return information for {module.__class__}")
raise NotImplementedError(
f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
f"in order for us to add support for this module."
)
self._inputs_index_to_str = inputs_index_to_str
self._inputs_str_to_index = inputs_str_to_index
self._outputs_index_to_str = outputs_index_to_str
self._outputs_str_to_index = outputs_str_to_index
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
hs_input_idx = self._inputs_str_to_index.get("hidden_states")
ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
original_hs = kwargs.get("hidden_states", None)
original_ehs = kwargs.get("encoder_hidden_states", None)
original_hs = original_hs if original_hs is not None else args[hs_input_idx]
if ehs_input_idx is not None:
original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]
hs_output_idx = self._outputs_str_to_index.get("hidden_states")
ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
assert (ehs_input_idx is None) == (ehs_output_idx is None)
output = self.fn_ref.original_forward(*args, **kwargs)
hs_residual = None
if isinstance(output, tuple):
hs_residual = output[hs_output_idx] - original_hs
else:
hs_residual = output - original_hs
should_compute = self._should_compute_remaining_blocks(hs_residual)
self.shared_state.should_compute = should_compute
hs, ehs = None, None
if not should_compute:
# Apply caching
logger.info("Skipping forward pass through remaining blocks")
hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx]
if ehs_output_idx is not None:
ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx]
if isinstance(output, tuple):
return_output = [None] * len(output)
return_output[hs_output_idx] = hs
return_output[ehs_output_idx] = ehs
return_output = tuple(return_output)
else:
return_output = hs
return return_output
else:
logger.info("Computing forward pass through remaining blocks")
if isinstance(output, tuple):
head_block_output = [None] * len(output)
head_block_output[0] = output[hs_output_idx]
head_block_output[1] = output[ehs_output_idx]
else:
head_block_output = output
self.shared_state.head_block_output = head_block_output
self.shared_state.head_block_residual = hs_residual
return output
def reset_state(self, module):
self.shared_state.reset()
return module
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
if self.shared_state.head_block_residual is None:
return True
prev_hs_residual = self.shared_state.head_block_residual
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
prev_hs_mean = prev_hs_residual.abs().mean()
diff = (hs_absmean / prev_hs_mean).item()
logger.info(f"Diff: {diff}, Threshold: {self.threshold}")
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
super().__init__()
self.shared_state = shared_state
self.is_tail = is_tail
def initialize_hook(self, module):
inputs = inspect.signature(module.__class__.forward)
inputs_index_to_str = dict(enumerate(inputs.parameters.keys()))
inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()}
try:
outputs = _extract_return_information(module.__class__.forward)
outputs_index_to_str = dict(enumerate(outputs))
outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()}
except RuntimeError:
logger.error(f"Failed to extract return information for {module.__class__}")
raise NotImplementedError(
f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at "
f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example "
f"in order for us to add support for this module."
)
self._inputs_index_to_str = inputs_index_to_str
self._inputs_str_to_index = inputs_str_to_index
self._outputs_index_to_str = outputs_index_to_str
self._outputs_str_to_index = outputs_str_to_index
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
hs_input_idx = self._inputs_str_to_index.get("hidden_states")
ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None)
original_hs = kwargs.get("hidden_states", None)
original_ehs = kwargs.get("encoder_hidden_states", None)
original_hs = original_hs if original_hs is not None else args[hs_input_idx]
if ehs_input_idx is not None:
original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx]
hs_output_idx = self._outputs_str_to_index.get("hidden_states")
ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None)
assert (ehs_input_idx is None) == (ehs_output_idx is None)
if self.shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hs_residual, ehs_residual = None, None
if isinstance(output, tuple):
hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0]
ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1]
else:
hs_residual = output - self.shared_state.head_block_output
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
return output
output_count = len(self._outputs_index_to_str.keys())
return_output = [None] * output_count if output_count > 1 else original_hs
if output_count == 1:
return_output = original_hs
else:
return_output[hs_output_idx] = original_hs
return_output[ehs_output_idx] = original_ehs
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
shared_state = FBCSharedBlockState()
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for block in submodule:
remaining_blocks.append((name, block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Apply FBCBlockHook to '{name}'")
apply_fbc_block_hook(block, shared_state)
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)
+59
View File
@@ -0,0 +1,59 @@
import ast
import inspect
import textwrap
from typing import List
def _extract_return_information(func) -> List[str]:
"""Extracts return variable names in order from a function."""
try:
source = inspect.getsource(func)
source = textwrap.dedent(source) # Modify indentation to make parsing compatible
except (OSError, TypeError):
try:
source_file = inspect.getfile(func)
with open(source_file, "r", encoding="utf-8") as f:
source = f.read()
# Extract function definition manually
source_lines = source.splitlines()
func_name = func.__name__
start_line = None
indent_level = None
extracted_lines = []
for i, line in enumerate(source_lines):
stripped = line.strip()
if stripped.startswith(f"def {func_name}("):
start_line = i
indent_level = len(line) - len(line.lstrip())
extracted_lines.append(line)
continue
if start_line is not None:
# Stop when indentation level decreases (end of function)
current_indent = len(line) - len(line.lstrip())
if current_indent <= indent_level and line.strip():
break
extracted_lines.append(line)
source = "\n".join(extracted_lines)
except Exception as e:
raise RuntimeError(f"Failed to retrieve function source: {e}")
# Parse source code using AST
tree = ast.parse(source)
return_vars = []
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
if isinstance(node.value, ast.Tuple):
# Multiple return values
return_vars.extend(var.id for var in node.value.elts if isinstance(var, ast.Name))
elif isinstance(node.value, ast.Name):
# Single return value
return_vars.append(node.value.id)
visitor = ReturnVisitor()
visitor.visit(tree)
return return_vars
@@ -87,10 +87,13 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -108,7 +111,10 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
encoder_hidden_states, hidden_states = hidden_states.split(
[encoder_hidden_states.size(1), hidden_states.size(1) - encoder_hidden_states.size(1)], dim=1
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
@@ -224,7 +230,7 @@ class FluxTransformerBlock(nn.Module):
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
return hidden_states, encoder_hidden_states
class FluxTransformer2DModel(
@@ -517,7 +523,7 @@ class FluxTransformer2DModel(
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
@@ -526,7 +532,7 @@ class FluxTransformer2DModel(
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
@@ -545,20 +551,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -568,12 +575,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
@@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipeline(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetMotionModel],
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
@@ -246,7 +246,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetMotionModel],
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: Union[
@@ -232,8 +232,8 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
Tuple[HunyuanDiT2DControlNetModel],
HunyuanDiT2DMultiControlNetModel,
],
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
additional conditioning.
image_encoder (`SiglipVisionModel`, *optional*):
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`SiglipImageProcessor`, *optional*):
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
@@ -202,8 +202,8 @@ class StableDiffusion3ControlNetPipeline(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
SiglipImageProcessor,
SiglipModel,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
@@ -223,8 +223,8 @@ class StableDiffusion3ControlNetInpaintingPipeline(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
image_encoder: SiglipModel = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -17,8 +17,6 @@ from typing import List, Optional, Tuple, Union
import torch
from ...models import UNet1DModel
from ...schedulers import SchedulerMixin
from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -51,7 +49,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
model_cpu_offload_seq = "unet"
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@@ -16,7 +16,6 @@ from typing import List, Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...schedulers import DDIMScheduler
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
@@ -48,7 +47,7 @@ class DDIMPipeline(DiffusionPipeline):
model_cpu_offload_seq = "unet"
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
def __init__(self, unet, scheduler):
super().__init__()
# make sure scheduler can always be converted to DDIM
@@ -17,8 +17,6 @@ from typing import List, Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...schedulers import DDPMScheduler
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -49,7 +47,7 @@ class DDPMPipeline(DiffusionPipeline):
model_cpu_offload_seq = "unet"
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
scheduler: RePaintScheduler
model_cpu_offload_seq = "unet"
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@@ -207,8 +207,8 @@ class HunyuanDiTPipeline(DiffusionPipeline):
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,
):
super().__init__()
@@ -20,7 +20,7 @@ import urllib.parse as ul
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
from transformers import AutoModel, AutoTokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
@@ -144,10 +144,13 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`GemmaPreTrainedModel`]):
Frozen Gemma text-encoder.
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
Gemma tokenizer.
text_encoder ([`AutoModel`]):
Frozen text-encoder. Lumina-T2I uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
tokenizer (`AutoModel`):
Tokenizer of class
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
@@ -182,8 +185,8 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
transformer: LuminaNextDiT2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: GemmaPreTrainedModel,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: AutoModel,
tokenizer: AutoTokenizer,
):
super().__init__()
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
from transformers import AutoModel, AutoTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import Lumina2LoraLoaderMixin
@@ -143,10 +143,13 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Gemma2PreTrainedModel`]):
Frozen Gemma2 text-encoder.
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
Gemma tokenizer.
text_encoder ([`AutoModel`]):
Frozen text-encoder. Lumina-T2I uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
tokenizer (`AutoModel`):
Tokenizer of class
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
@@ -162,8 +165,8 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
transformer: Lumina2Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: Gemma2PreTrainedModel,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: AutoModel,
tokenizer: AutoTokenizer,
):
super().__init__()
@@ -20,7 +20,7 @@ import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
def __init__(
self,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: Gemma2PreTrainedModel,
tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,
vae: AutoencoderDC,
transformer: SanaTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
@@ -17,7 +17,7 @@ import os
import re
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
from typing import Any, Callable, Dict, List, Optional, Union
import requests
import torch
@@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
break
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
"""
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
the correct type as well.
"""
if not isinstance(class_or_tuple, tuple):
class_or_tuple = (class_or_tuple,)
# Unpack unions
unpacked_class_or_tuple = []
for t in class_or_tuple:
if get_origin(t) is Union:
unpacked_class_or_tuple.extend(get_args(t))
else:
unpacked_class_or_tuple.append(t)
class_or_tuple = tuple(unpacked_class_or_tuple)
if Any in class_or_tuple:
return True
obj_type = type(obj)
# Classes with obj's type
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
# Singular types (e.g. int, ControlNet, ...)
# Untyped collections (e.g. List, but not List[int])
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
if () in elem_class_or_tuple:
return True
# Typed lists or sets
elif obj_type in (list, set):
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
# Typed tuples
elif obj_type is tuple:
return any(
# Tuples with any length and single type (e.g. Tuple[int, ...])
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
or
# Tuples with fixed length and any types (e.g. Tuple[int, str])
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
for t in elem_class_or_tuple
)
# Typed dicts
elif obj_type is dict:
return any(
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
for kt, vt in elem_class_or_tuple
)
else:
return False
def _get_detailed_type(obj: Any) -> Type:
"""
Gets a detailed type for an object, including nested types for collections.
"""
obj_type = type(obj)
if obj_type in (list, set):
obj_origin_type = List if obj_type is list else Set
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
return obj_origin_type[elems_type]
elif obj_type is tuple:
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
elif obj_type is dict:
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
return Dict[keys_type, values_type]
else:
return obj_type
+23 -20
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import fnmatch
import importlib
import inspect
@@ -78,12 +79,10 @@ from .pipeline_loading_utils import (
_fetch_class_library_tuple,
_get_custom_components_and_folders,
_get_custom_pipeline_class,
_get_detailed_type,
_get_final_device_map,
_get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
_is_valid_type,
_maybe_raise_error_for_incorrect_transformers,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
@@ -877,6 +876,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
for key in init_dict.keys():
if key not in passed_class_obj:
continue
if "scheduler" in key:
continue
class_obj = passed_class_obj[key]
_expected_class_types = []
for expected_type in expected_types[key]:
if isinstance(expected_type, enum.EnumMeta):
_expected_class_types.extend(expected_type.__members__.keys())
else:
_expected_class_types.append(expected_type.__name__)
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if not _is_valid_type:
logger.warning(
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
)
# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
raise NotImplementedError(
@@ -996,26 +1015,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# 10. Type checking init arguments
for kw, arg in init_kwargs.items():
# Too complex to validate with type annotation alone
if "scheduler" in kw:
continue
# Many tokenizer annotations don't include its "Fast" variant, so skip this
# e.g T5Tokenizer but not T5TokenizerFast
elif "tokenizer" in kw:
continue
elif (
arg is not None # Skip if None
and not expected_types[kw] == (inspect.Signature.empty,) # Skip if no type annotations
and not _is_valid_type(arg, expected_types[kw]) # Check type
):
logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
# 11. Instantiate the pipeline
# 10. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
# 12. Save where the model was instantiated from
# 11. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
@@ -20,7 +20,7 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
@@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
def __init__(
self,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: Gemma2PreTrainedModel,
tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,
vae: AutoencoderDC,
transformer: SanaTransformer2DModel,
scheduler: DPMSolverMultistepScheduler,
@@ -15,7 +15,7 @@
from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
@@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
Args:
tokenizer (`CLIPTokenizer`):
The CLIP tokenizer.
text_encoder (`CLIPTextModelWithProjection`):
text_encoder (`CLIPTextModel`):
The CLIP text encoder.
decoder ([`StableCascadeUNet`]):
The Stable Cascade decoder unet.
@@ -93,7 +93,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
self,
decoder: StableCascadeUNet,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
text_encoder: CLIPTextModel,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
latent_dim_scale: float = 10.67,
@@ -15,7 +15,7 @@ from typing import Callable, Dict, List, Optional, Union
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
@@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
Args:
tokenizer (`CLIPTokenizer`):
The decoder tokenizer to be used for text inputs.
text_encoder (`CLIPTextModelWithProjection`):
text_encoder (`CLIPTextModel`):
The decoder text encoder to be used for text inputs.
decoder (`StableCascadeUNet`):
The decoder model to be used for decoder image generation pipeline.
@@ -60,18 +60,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
The scheduler to be used for decoder image generation pipeline.
vqgan (`PaellaVQModel`):
The VQGAN model to be used for decoder image generation pipeline.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
prior_prior (`StableCascadeUNet`):
The prior model to be used for prior pipeline.
prior_text_encoder (`CLIPTextModelWithProjection`):
The prior text encoder to be used for text inputs.
prior_tokenizer (`CLIPTokenizer`):
The prior tokenizer to be used for text inputs.
prior_scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for prior pipeline.
prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
"""
_load_connected_pipes = True
@@ -80,12 +76,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
text_encoder: CLIPTextModel,
decoder: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
prior_prior: StableCascadeUNet,
prior_text_encoder: CLIPTextModelWithProjection,
prior_text_encoder: CLIPTextModel,
prior_tokenizer: CLIPTokenizer,
prior_scheduler: DDPMWuerstchenScheduler,
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
@@ -141,7 +141,7 @@ class StableUnCLIPPipeline(
image_noising_scheduler: KarrasDiffusionSchedulers,
# regular denoising components
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
text_encoder: CLIPTextModelWithProjection,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
# vae
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
@@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`SiglipVisionModel`, *optional*):
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`SiglipImageProcessor`, *optional*):
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
@@ -197,8 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: SiglipVisionModel = None,
feature_extractor: SiglipImageProcessor = None,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
@@ -197,10 +197,6 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`SiglipVisionModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`SiglipImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
@@ -218,8 +214,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`SiglipVisionModel`, *optional*):
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`SiglipImageProcessor`, *optional*):
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
@@ -217,8 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -19,31 +19,15 @@ from typing import Callable, List, Optional, Union
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPTokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import (
StableDiffusionLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, UNet2DConditionModel
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...schedulers import LMSDiscreteScheduler
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from ..stable_diffusion import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -111,13 +95,13 @@ class StableDiffusionKDiffusionPipeline(
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker: bool = True,
):
super().__init__()
+2 -2
View File
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
from diffusers import DiffusionPipeline, ImagePipelineOutput
class CustomLocalPipeline(DiffusionPipeline):
@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
+1 -2
View File
@@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union
import torch
from diffusers import SchedulerMixin, UNet2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -34,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@@ -91,10 +91,10 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
text_encoder = Gemma2Model(config)
components = {
"transformer": transformer,
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder": text_encoder.eval(),
"tokenizer": tokenizer,
}
return components