Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 98954fc2e1 | |||
| 1262d19d16 | |||
| 201da97dd0 | |||
| f36ba9f094 | |||
| 1c50a5f7e0 | |||
| 7ae6347e33 | |||
| 178d32dedd | |||
| ef1e628729 | |||
| 173e1b147d | |||
| e46e139f95 | |||
| 4423097b23 | |||
| 14725164be | |||
| 638cc035e5 |
@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
args,
|
||||||
instance_data_root,
|
instance_data_root,
|
||||||
instance_prompt,
|
instance_prompt,
|
||||||
class_prompt,
|
class_prompt,
|
||||||
@@ -980,10 +981,8 @@ class DreamBoothDataset(Dataset):
|
|||||||
class_num=None,
|
class_num=None,
|
||||||
size=1024,
|
size=1024,
|
||||||
repeats=1,
|
repeats=1,
|
||||||
center_crop=False,
|
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.center_crop = center_crop
|
|
||||||
|
|
||||||
self.instance_prompt = instance_prompt
|
self.instance_prompt = instance_prompt
|
||||||
self.custom_instance_prompts = None
|
self.custom_instance_prompts = None
|
||||||
@@ -1058,7 +1057,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
if interpolation is None:
|
if interpolation is None:
|
||||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||||
train_resize = transforms.Resize(size, interpolation=interpolation)
|
train_resize = transforms.Resize(size, interpolation=interpolation)
|
||||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)
|
||||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||||
train_transforms = transforms.Compose(
|
train_transforms = transforms.Compose(
|
||||||
[
|
[
|
||||||
@@ -1075,11 +1074,11 @@ class DreamBoothDataset(Dataset):
|
|||||||
# flip
|
# flip
|
||||||
image = train_flip(image)
|
image = train_flip(image)
|
||||||
if args.center_crop:
|
if args.center_crop:
|
||||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
y1 = max(0, int(round((image.height - self.size) / 2.0)))
|
||||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
x1 = max(0, int(round((image.width - self.size) / 2.0)))
|
||||||
image = train_crop(image)
|
image = train_crop(image)
|
||||||
else:
|
else:
|
||||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
|
||||||
image = crop(image, y1, x1, h, w)
|
image = crop(image, y1, x1, h, w)
|
||||||
image = train_transforms(image)
|
image = train_transforms(image)
|
||||||
self.pixel_values.append(image)
|
self.pixel_values.append(image)
|
||||||
@@ -1102,7 +1101,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
self.image_transforms = transforms.Compose(
|
self.image_transforms = transforms.Compose(
|
||||||
[
|
[
|
||||||
transforms.Resize(size, interpolation=interpolation),
|
transforms.Resize(size, interpolation=interpolation),
|
||||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.5], [0.5]),
|
transforms.Normalize([0.5], [0.5]),
|
||||||
]
|
]
|
||||||
@@ -1827,6 +1826,7 @@ def main(args):
|
|||||||
|
|
||||||
# Dataset and DataLoaders creation:
|
# Dataset and DataLoaders creation:
|
||||||
train_dataset = DreamBoothDataset(
|
train_dataset = DreamBoothDataset(
|
||||||
|
args=args,
|
||||||
instance_data_root=args.instance_data_dir,
|
instance_data_root=args.instance_data_dir,
|
||||||
instance_prompt=args.instance_prompt,
|
instance_prompt=args.instance_prompt,
|
||||||
train_text_encoder_ti=args.train_text_encoder_ti,
|
train_text_encoder_ti=args.train_text_encoder_ti,
|
||||||
@@ -1836,7 +1836,6 @@ def main(args):
|
|||||||
class_num=args.num_class_images,
|
class_num=args.num_class_images,
|
||||||
size=args.resolution,
|
size=args.resolution,
|
||||||
repeats=args.repeats,
|
repeats=args.repeats,
|
||||||
center_crop=args.center_crop,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
|||||||
@@ -366,6 +366,8 @@ else:
|
|||||||
[
|
[
|
||||||
"StableDiffusionXLAutoBlocks",
|
"StableDiffusionXLAutoBlocks",
|
||||||
"StableDiffusionXLModularPipeline",
|
"StableDiffusionXLModularPipeline",
|
||||||
|
"WanAutoBlocks",
|
||||||
|
"WanModularPipeline",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["pipelines"].extend(
|
_import_structure["pipelines"].extend(
|
||||||
@@ -999,6 +1001,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .modular_pipelines import (
|
from .modular_pipelines import (
|
||||||
StableDiffusionXLAutoBlocks,
|
StableDiffusionXLAutoBlocks,
|
||||||
StableDiffusionXLModularPipeline,
|
StableDiffusionXLModularPipeline,
|
||||||
|
WanAutoBlocks,
|
||||||
|
WanModularPipeline,
|
||||||
)
|
)
|
||||||
from .pipelines import (
|
from .pipelines import (
|
||||||
AllegroPipeline,
|
AllegroPipeline,
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ class TransformerBlockRegistry:
|
|||||||
def _register_attention_processors_metadata():
|
def _register_attention_processors_metadata():
|
||||||
from ..models.attention_processor import AttnProcessor2_0
|
from ..models.attention_processor import AttnProcessor2_0
|
||||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
||||||
|
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||||
|
|
||||||
# AttnProcessor2_0
|
# AttnProcessor2_0
|
||||||
AttentionProcessorRegistry.register(
|
AttentionProcessorRegistry.register(
|
||||||
@@ -124,6 +125,14 @@ def _register_attention_processors_metadata():
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# WanAttnProcessor2_0
|
||||||
|
AttentionProcessorRegistry.register(
|
||||||
|
model_class=WanAttnProcessor2_0,
|
||||||
|
metadata=AttentionProcessorMetadata(
|
||||||
|
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _register_transformer_blocks_metadata():
|
def _register_transformer_blocks_metadata():
|
||||||
from ..models.attention import BasicTransformerBlock
|
from ..models.attention import BasicTransformerBlock
|
||||||
@@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
|
|||||||
|
|
||||||
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
|
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||||
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
||||||
|
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|||||||
@@ -91,10 +91,19 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
|||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if func is torch.nn.functional.scaled_dot_product_attention:
|
if func is torch.nn.functional.scaled_dot_product_attention:
|
||||||
|
query = kwargs.get("query", None)
|
||||||
|
key = kwargs.get("key", None)
|
||||||
value = kwargs.get("value", None)
|
value = kwargs.get("value", None)
|
||||||
if value is None:
|
query = query if query is not None else args[0]
|
||||||
value = args[2]
|
key = key if key is not None else args[1]
|
||||||
return value
|
value = value if value is not None else args[2]
|
||||||
|
# If the Q sequence length does not match KV sequence length, methods like
|
||||||
|
# Perturbed Attention Guidance cannot be used (because the caller expects
|
||||||
|
# the same sequence length as Q, but if we return V here, it will not match).
|
||||||
|
# When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
|
||||||
|
# the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
|
||||||
|
if query.shape[2] == value.shape[2]:
|
||||||
|
return value
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -38,18 +38,29 @@ from ..utils import (
|
|||||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
_REQUIRED_FLASH_VERSION = "2.6.3"
|
||||||
|
_REQUIRED_SAGE_VERSION = "2.1.1"
|
||||||
|
_REQUIRED_FLEX_VERSION = "2.5.0"
|
||||||
|
_REQUIRED_XLA_VERSION = "2.2"
|
||||||
|
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||||
|
|
||||||
|
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||||
|
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||||
|
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
|
||||||
|
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
|
||||||
|
_CAN_USE_NPU_ATTN = is_torch_npu_available()
|
||||||
|
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
|
||||||
|
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
|
if _CAN_USE_FLASH_ATTN:
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
|
|
||||||
flash_attn_func = None
|
flash_attn_func = None
|
||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_3_available():
|
if _CAN_USE_FLASH_ATTN_3:
|
||||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||||
else:
|
else:
|
||||||
@@ -57,7 +68,7 @@ else:
|
|||||||
flash_attn_3_varlen_func = None
|
flash_attn_3_varlen_func = None
|
||||||
|
|
||||||
|
|
||||||
if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
|
if _CAN_USE_SAGE_ATTN:
|
||||||
from sageattention import (
|
from sageattention import (
|
||||||
sageattn,
|
sageattn,
|
||||||
sageattn_qk_int8_pv_fp8_cuda,
|
sageattn_qk_int8_pv_fp8_cuda,
|
||||||
@@ -67,9 +78,6 @@ if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
|
|||||||
sageattn_varlen,
|
sageattn_varlen,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
|
||||||
"`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
|
|
||||||
)
|
|
||||||
sageattn = None
|
sageattn = None
|
||||||
sageattn_qk_int8_pv_fp16_cuda = None
|
sageattn_qk_int8_pv_fp16_cuda = None
|
||||||
sageattn_qk_int8_pv_fp16_triton = None
|
sageattn_qk_int8_pv_fp16_triton = None
|
||||||
@@ -78,39 +86,39 @@ else:
|
|||||||
sageattn_varlen = None
|
sageattn_varlen = None
|
||||||
|
|
||||||
|
|
||||||
if is_torch_version(">=", "2.5.0"):
|
if _CAN_USE_FLEX_ATTN:
|
||||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||||
# compiled function.
|
# compiled function.
|
||||||
import torch.nn.attention.flex_attention as flex_attention
|
import torch.nn.attention.flex_attention as flex_attention
|
||||||
|
|
||||||
|
|
||||||
if is_torch_npu_available():
|
if _CAN_USE_NPU_ATTN:
|
||||||
from torch_npu import npu_fusion_attention
|
from torch_npu import npu_fusion_attention
|
||||||
else:
|
else:
|
||||||
npu_fusion_attention = None
|
npu_fusion_attention = None
|
||||||
|
|
||||||
|
|
||||||
if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
|
if _CAN_USE_XLA_ATTN:
|
||||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||||
else:
|
else:
|
||||||
xla_flash_attention = None
|
xla_flash_attention = None
|
||||||
|
|
||||||
|
|
||||||
if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
|
if _CAN_USE_XFORMERS_ATTN:
|
||||||
import xformers.ops as xops
|
import xformers.ops as xops
|
||||||
else:
|
else:
|
||||||
logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
|
|
||||||
xops = None
|
xops = None
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
# TODO(aryan): Add support for the following:
|
# TODO(aryan): Add support for the following:
|
||||||
# - Sage Attention++
|
# - Sage Attention++
|
||||||
# - block sparse, radial and other attention methods
|
# - block sparse, radial and other attention methods
|
||||||
# - CP with sage attention, flex, xformers, other missing backends
|
# - CP with sage attention, flex, xformers, other missing backends
|
||||||
# - Add support for normal and CP training with backends that don't support it yet
|
# - Add support for normal and CP training with backends that don't support it yet
|
||||||
|
|
||||||
|
|
||||||
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
||||||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
||||||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
||||||
@@ -179,13 +187,16 @@ class _AttentionBackendRegistry:
|
|||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
|
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
|
||||||
"""
|
"""
|
||||||
Context manager to set the active attention backend.
|
Context manager to set the active attention backend.
|
||||||
"""
|
"""
|
||||||
if backend not in _AttentionBackendRegistry._backends:
|
if backend not in _AttentionBackendRegistry._backends:
|
||||||
raise ValueError(f"Backend {backend} is not registered.")
|
raise ValueError(f"Backend {backend} is not registered.")
|
||||||
|
|
||||||
|
backend = AttentionBackendName(backend)
|
||||||
|
_check_attention_backend_requirements(backend)
|
||||||
|
|
||||||
old_backend = _AttentionBackendRegistry._active_backend
|
old_backend = _AttentionBackendRegistry._active_backend
|
||||||
_AttentionBackendRegistry._active_backend = backend
|
_AttentionBackendRegistry._active_backend = backend
|
||||||
|
|
||||||
@@ -226,9 +237,10 @@ def dispatch_attention_fn(
|
|||||||
"dropout_p": dropout_p,
|
"dropout_p": dropout_p,
|
||||||
"is_causal": is_causal,
|
"is_causal": is_causal,
|
||||||
"scale": scale,
|
"scale": scale,
|
||||||
"enable_gqa": enable_gqa,
|
|
||||||
**attention_kwargs,
|
**attention_kwargs,
|
||||||
}
|
}
|
||||||
|
if is_torch_version(">=", "2.5.0"):
|
||||||
|
kwargs["enable_gqa"] = enable_gqa
|
||||||
|
|
||||||
if _AttentionBackendRegistry._checks_enabled:
|
if _AttentionBackendRegistry._checks_enabled:
|
||||||
removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
|
removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
|
||||||
@@ -305,6 +317,57 @@ def _check_shape(
|
|||||||
# ===== Helper functions =====
|
# ===== Helper functions =====
|
||||||
|
|
||||||
|
|
||||||
|
def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
|
||||||
|
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
|
||||||
|
if not _CAN_USE_FLASH_ATTN:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
|
||||||
|
if not _CAN_USE_FLASH_ATTN_3:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif backend in [
|
||||||
|
AttentionBackendName.SAGE,
|
||||||
|
AttentionBackendName.SAGE_VARLEN,
|
||||||
|
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
|
||||||
|
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
|
||||||
|
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
|
||||||
|
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
|
||||||
|
]:
|
||||||
|
if not _CAN_USE_SAGE_ATTN:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif backend == AttentionBackendName.FLEX:
|
||||||
|
if not _CAN_USE_FLEX_ATTN:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif backend == AttentionBackendName._NATIVE_NPU:
|
||||||
|
if not _CAN_USE_NPU_ATTN:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif backend == AttentionBackendName._NATIVE_XLA:
|
||||||
|
if not _CAN_USE_XLA_ATTN:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif backend == AttentionBackendName.XFORMERS:
|
||||||
|
if not _CAN_USE_XFORMERS_ATTN:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=128)
|
@functools.lru_cache(maxsize=128)
|
||||||
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
|||||||
@@ -622,19 +622,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
attention as backend.
|
attention as backend.
|
||||||
"""
|
"""
|
||||||
from .attention import AttentionModuleMixin
|
from .attention import AttentionModuleMixin
|
||||||
from .attention_dispatch import AttentionBackendName
|
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
|
||||||
|
|
||||||
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
|
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
|
||||||
from .attention_processor import Attention, MochiAttention
|
from .attention_processor import Attention, MochiAttention
|
||||||
|
|
||||||
|
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
|
||||||
|
|
||||||
backend = backend.lower()
|
backend = backend.lower()
|
||||||
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
||||||
if backend not in available_backends:
|
if backend not in available_backends:
|
||||||
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
||||||
|
|
||||||
backend = AttentionBackendName(backend)
|
backend = AttentionBackendName(backend)
|
||||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
_check_attention_backend_requirements(backend)
|
||||||
|
|
||||||
|
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if not isinstance(module, attention_classes):
|
if not isinstance(module, attention_classes):
|
||||||
continue
|
continue
|
||||||
@@ -651,6 +653,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
from .attention import AttentionModuleMixin
|
from .attention import AttentionModuleMixin
|
||||||
from .attention_processor import Attention, MochiAttention
|
from .attention_processor import Attention, MochiAttention
|
||||||
|
|
||||||
|
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
|
||||||
|
|
||||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if not isinstance(module, attention_classes):
|
if not isinstance(module, attention_classes):
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class UNet2DConditionModel(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_supports_gradient_checkpointing = True
|
_supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
|
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
|
||||||
_skip_layerwise_casting_patterns = ["norm"]
|
_skip_layerwise_casting_patterns = ["norm"]
|
||||||
_repeated_blocks = ["BasicTransformerBlock"]
|
_repeated_blocks = ["BasicTransformerBlock"]
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ else:
|
|||||||
"InsertableDict",
|
"InsertableDict",
|
||||||
]
|
]
|
||||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||||
|
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
|
||||||
_import_structure["components_manager"] = ["ComponentsManager"]
|
_import_structure["components_manager"] = ["ComponentsManager"]
|
||||||
|
|
||||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
@@ -71,6 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionXLAutoBlocks,
|
StableDiffusionXLAutoBlocks,
|
||||||
StableDiffusionXLModularPipeline,
|
StableDiffusionXLModularPipeline,
|
||||||
)
|
)
|
||||||
|
from .wan import WanAutoBlocks, WanModularPipeline
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -386,6 +386,7 @@ class ComponentsManager:
|
|||||||
id(component) is Python's built-in unique identifier for the object
|
id(component) is Python's built-in unique identifier for the object
|
||||||
"""
|
"""
|
||||||
component_id = f"{name}_{id(component)}"
|
component_id = f"{name}_{id(component)}"
|
||||||
|
is_new_component = True
|
||||||
|
|
||||||
# check for duplicated components
|
# check for duplicated components
|
||||||
for comp_id, comp in self.components.items():
|
for comp_id, comp in self.components.items():
|
||||||
@@ -394,6 +395,7 @@ class ComponentsManager:
|
|||||||
if comp_name == name:
|
if comp_name == name:
|
||||||
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
|
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
|
||||||
component_id = comp_id
|
component_id = comp_id
|
||||||
|
is_new_component = False
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -426,7 +428,9 @@ class ComponentsManager:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
|
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
|
||||||
)
|
)
|
||||||
self.remove(comp_id)
|
# remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
|
||||||
|
self.remove_from_collection(comp_id, collection)
|
||||||
|
|
||||||
self.collections[collection].add(component_id)
|
self.collections[collection].add(component_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
|
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
|
||||||
@@ -434,11 +438,29 @@ class ComponentsManager:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
|
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
|
||||||
|
|
||||||
if self._auto_offload_enabled:
|
if self._auto_offload_enabled and is_new_component:
|
||||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||||
|
|
||||||
return component_id
|
return component_id
|
||||||
|
|
||||||
|
def remove_from_collection(self, component_id: str, collection: str):
|
||||||
|
"""
|
||||||
|
Remove a component from a collection.
|
||||||
|
"""
|
||||||
|
if collection not in self.collections:
|
||||||
|
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
||||||
|
return
|
||||||
|
if component_id not in self.collections[collection]:
|
||||||
|
logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
|
||||||
|
return
|
||||||
|
# remove from the collection
|
||||||
|
self.collections[collection].remove(component_id)
|
||||||
|
# check if this component is in any other collection
|
||||||
|
comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
|
||||||
|
if not comp_colls: # only if no other collection contains this component, remove it
|
||||||
|
logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
|
||||||
|
self.remove(component_id)
|
||||||
|
|
||||||
def remove(self, component_id: str = None):
|
def remove(self, component_id: str = None):
|
||||||
"""
|
"""
|
||||||
Remove a component from the ComponentsManager.
|
Remove a component from the ComponentsManager.
|
||||||
|
|||||||
@@ -60,12 +60,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||||
|
("wan", "WanModularPipeline"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
|
MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
|
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
|
||||||
|
("WanModularPipeline", "WanAutoBlocks"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -322,9 +324,12 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
</Tip>
|
</Tip>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_name = "config.json"
|
config_name = "modular_config.json"
|
||||||
model_name = None
|
model_name = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.sub_blocks = InsertableDict()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_signature_keys(cls, obj):
|
def _get_signature_keys(cls, obj):
|
||||||
parameters = inspect.signature(obj.__init__).parameters
|
parameters = inspect.signature(obj.__init__).parameters
|
||||||
@@ -342,11 +347,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
def expected_configs(self) -> List[ConfigSpec]:
|
def expected_configs(self) -> List[ConfigSpec]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
|
||||||
def intermediate_inputs(self) -> List[OutputParam]:
|
|
||||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def intermediate_outputs(self) -> List[OutputParam]:
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||||
@@ -1456,11 +1456,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
|||||||
"""List of input parameters. Must be implemented by subclasses."""
|
"""List of input parameters. Must be implemented by subclasses."""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
|
||||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
|
||||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||||
@@ -1474,14 +1469,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
|||||||
input_names.append(input_param.name)
|
input_names.append(input_param.name)
|
||||||
return input_names
|
return input_names
|
||||||
|
|
||||||
@property
|
|
||||||
def loop_required_intermediate_inputs(self) -> List[str]:
|
|
||||||
input_names = []
|
|
||||||
for input_param in self.loop_intermediate_inputs:
|
|
||||||
if input_param.required:
|
|
||||||
input_names.append(input_param.name)
|
|
||||||
return input_names
|
|
||||||
|
|
||||||
# modified from SequentialPipelineBlocks to include loop_expected_components
|
# modified from SequentialPipelineBlocks to include loop_expected_components
|
||||||
@property
|
@property
|
||||||
def expected_components(self):
|
def expected_components(self):
|
||||||
|
|||||||
@@ -185,6 +185,8 @@ class ComponentSpec:
|
|||||||
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
|
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
|
||||||
segments).
|
segments).
|
||||||
"""
|
"""
|
||||||
|
if self.default_creation_method == "from_config":
|
||||||
|
return "null"
|
||||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||||
parts = ["null" if p is None else p for p in parts]
|
parts = ["null" if p is None else p for p in parts]
|
||||||
return "|".join(p for p in parts if p)
|
return "|".join(p for p in parts if p)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
|||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import (
|
||||||
PipelineBlock,
|
ModularPipelineBlocks,
|
||||||
PipelineState,
|
PipelineState,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInputStep(PipelineBlock):
|
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -394,7 +394,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -543,7 +543,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -611,7 +611,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -900,7 +900,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -981,7 +981,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1092,7 +1092,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1316,7 +1316,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1499,7 +1499,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1718,7 +1718,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -23,17 +23,14 @@ from ...image_processor import VaeImageProcessor
|
|||||||
from ...models import AutoencoderKL
|
from ...models import AutoencoderKL
|
||||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
PipelineBlock,
|
|
||||||
PipelineState,
|
|
||||||
)
|
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -157,7 +154,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
|||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import (
|
||||||
BlockState,
|
BlockState,
|
||||||
LoopSequentialPipelineBlocks,
|
LoopSequentialPipelineBlocks,
|
||||||
PipelineBlock,
|
ModularPipelineBlocks,
|
||||||
PipelineState,
|
PipelineState,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
# YiYi experimenting composible denoise loop
|
# YiYi experimenting composible denoise loop
|
||||||
# loop step (1): prepare latent input for denoiser
|
# loop step (1): prepare latent input for denoiser
|
||||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def intermediate_inputs(self) -> List[str]:
|
def inputs(self) -> List[str]:
|
||||||
return [
|
return [
|
||||||
InputParam(
|
InputParam(
|
||||||
"latents",
|
"latents",
|
||||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (2): denoise the latents with guidance
|
# loop step (2): denoise the latents with guidance
|
||||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -249,7 +249,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -449,7 +449,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (3): scheduler step to update latents
|
# loop step (3): scheduler step to update latents
|
||||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -520,7 +520,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (3): scheduler step to update latents (with inpainting)
|
# loop step (3): scheduler step to update latents (with inpainting)
|
||||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -660,7 +660,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
def loop_inputs(self) -> List[InputParam]:
|
||||||
return [
|
return [
|
||||||
InputParam(
|
InputParam(
|
||||||
"timesteps",
|
"timesteps",
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from ...utils import (
|
|||||||
scale_lora_layers,
|
scale_lora_layers,
|
||||||
unscale_lora_layers,
|
unscale_lora_layers,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
|||||||
raise AttributeError("Could not access latents of provided encoder_output")
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -691,7 +691,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import (
|
||||||
|
DIFFUSERS_SLOW_IMPORT,
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
get_objects_from_module,
|
||||||
|
is_torch_available,
|
||||||
|
is_transformers_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_dummy_objects = {}
|
||||||
|
_import_structure = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
|
else:
|
||||||
|
_import_structure["encoders"] = ["WanTextEncoderStep"]
|
||||||
|
_import_structure["modular_blocks"] = [
|
||||||
|
"ALL_BLOCKS",
|
||||||
|
"AUTO_BLOCKS",
|
||||||
|
"TEXT2VIDEO_BLOCKS",
|
||||||
|
"WanAutoBeforeDenoiseStep",
|
||||||
|
"WanAutoBlocks",
|
||||||
|
"WanAutoBlocks",
|
||||||
|
"WanAutoDecodeStep",
|
||||||
|
"WanAutoDenoiseStep",
|
||||||
|
]
|
||||||
|
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
||||||
|
|
||||||
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
|
else:
|
||||||
|
from .encoders import WanTextEncoderStep
|
||||||
|
from .modular_blocks import (
|
||||||
|
ALL_BLOCKS,
|
||||||
|
AUTO_BLOCKS,
|
||||||
|
TEXT2VIDEO_BLOCKS,
|
||||||
|
WanAutoBeforeDenoiseStep,
|
||||||
|
WanAutoBlocks,
|
||||||
|
WanAutoDecodeStep,
|
||||||
|
WanAutoDenoiseStep,
|
||||||
|
)
|
||||||
|
from .modular_pipeline import WanModularPipeline
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__,
|
||||||
|
globals()["__file__"],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
@@ -0,0 +1,365 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...schedulers import UniPCMultistepScheduler
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import WanModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
|
||||||
|
# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
|
||||||
|
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
|
||||||
|
# configuration of guider is.
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
class WanInputStep(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Input processing step that:\n"
|
||||||
|
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||||
|
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
|
||||||
|
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
|
||||||
|
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
|
||||||
|
"have a final batch_size of batch_size * num_videos_per_prompt."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("num_videos_per_prompt", default=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"negative_prompt_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"batch_size",
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"dtype",
|
||||||
|
type_hint=torch.dtype,
|
||||||
|
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||||
|
description="text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"negative_prompt_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
|
||||||
|
description="negative text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def check_inputs(self, components, block_state):
|
||||||
|
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
|
||||||
|
if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
|
||||||
|
raise ValueError(
|
||||||
|
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||||
|
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||||
|
f" {block_state.negative_prompt_embeds.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(components, block_state)
|
||||||
|
|
||||||
|
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||||
|
block_state.dtype = block_state.prompt_embeds.dtype
|
||||||
|
|
||||||
|
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||||
|
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
|
||||||
|
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||||
|
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
if block_state.negative_prompt_embeds is not None:
|
||||||
|
_, seq_len, _ = block_state.negative_prompt_embeds.shape
|
||||||
|
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
|
||||||
|
1, block_state.num_videos_per_prompt, 1
|
||||||
|
)
|
||||||
|
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
|
||||||
|
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", UniPCMultistepScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that sets the scheduler's timesteps for inference"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("num_inference_steps", default=50),
|
||||||
|
InputParam("timesteps"),
|
||||||
|
InputParam("sigmas"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||||
|
OutputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of denoising steps to perform at inference time",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||||
|
components.scheduler,
|
||||||
|
block_state.num_inference_steps,
|
||||||
|
block_state.device,
|
||||||
|
block_state.timesteps,
|
||||||
|
block_state.sigmas,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("height", type_hint=int),
|
||||||
|
InputParam("width", type_hint=int),
|
||||||
|
InputParam("num_frames", type_hint=int),
|
||||||
|
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||||
|
InputParam("num_videos_per_prompt", type_hint=int, default=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("generator"),
|
||||||
|
InputParam(
|
||||||
|
"batch_size",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.",
|
||||||
|
),
|
||||||
|
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(components, block_state):
|
||||||
|
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||||
|
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||||
|
)
|
||||||
|
if block_state.num_frames is not None and (
|
||||||
|
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
|
||||||
|
def prepare_latents(
|
||||||
|
comp,
|
||||||
|
batch_size: int,
|
||||||
|
num_channels_latents: int = 16,
|
||||||
|
height: int = 480,
|
||||||
|
width: int = 832,
|
||||||
|
num_frames: int = 81,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if latents is not None:
|
||||||
|
return latents.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1
|
||||||
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
num_latent_frames,
|
||||||
|
int(height) // comp.vae_scale_factor_spatial,
|
||||||
|
int(width) // comp.vae_scale_factor_spatial,
|
||||||
|
)
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.height = block_state.height or components.default_height
|
||||||
|
block_state.width = block_state.width or components.default_width
|
||||||
|
block_state.num_frames = block_state.num_frames or components.default_num_frames
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
|
||||||
|
block_state.num_channels_latents = components.num_channels_latents
|
||||||
|
|
||||||
|
self.check_inputs(components, block_state)
|
||||||
|
|
||||||
|
block_state.latents = self.prepare_latents(
|
||||||
|
components,
|
||||||
|
block_state.batch_size * block_state.num_videos_per_prompt,
|
||||||
|
block_state.num_channels_latents,
|
||||||
|
block_state.height,
|
||||||
|
block_state.width,
|
||||||
|
block_state.num_frames,
|
||||||
|
block_state.dtype,
|
||||||
|
block_state.device,
|
||||||
|
block_state.generator,
|
||||||
|
block_state.latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...models import AutoencoderKLWan
|
||||||
|
from ...utils import logging
|
||||||
|
from ...video_processor import VideoProcessor
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class WanDecodeStep(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("vae", AutoencoderKLWan),
|
||||||
|
ComponentSpec(
|
||||||
|
"video_processor",
|
||||||
|
VideoProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 8}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that decodes the denoised latents into images"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("output_type", default="pil"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The denoised latents from the denoising step",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"videos",
|
||||||
|
type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
|
||||||
|
description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
vae_dtype = components.vae.dtype
|
||||||
|
|
||||||
|
if not block_state.output_type == "latent":
|
||||||
|
latents = block_state.latents
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(components.vae.config.latents_mean)
|
||||||
|
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||||
|
.to(latents.device, latents.dtype)
|
||||||
|
)
|
||||||
|
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||||
|
1, components.vae.config.z_dim, 1, 1, 1
|
||||||
|
).to(latents.device, latents.dtype)
|
||||||
|
latents = latents / latents_std + latents_mean
|
||||||
|
latents = latents.to(vae_dtype)
|
||||||
|
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
|
||||||
|
else:
|
||||||
|
block_state.videos = block_state.latents
|
||||||
|
|
||||||
|
block_state.videos = components.video_processor.postprocess_video(
|
||||||
|
block_state.videos, output_type=block_state.output_type
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
@@ -0,0 +1,261 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...guiders import ClassifierFreeGuidance
|
||||||
|
from ...models import WanTransformer3DModel
|
||||||
|
from ...schedulers import UniPCMultistepScheduler
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import (
|
||||||
|
BlockState,
|
||||||
|
LoopSequentialPipelineBlocks,
|
||||||
|
ModularPipelineBlocks,
|
||||||
|
PipelineState,
|
||||||
|
)
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import WanModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 5.0}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
ComponentSpec("transformer", WanTransformer3DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Step within the denoising loop that denoise the latents with guidance. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("attention_kwargs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
kwargs_type="guider_input_fields",
|
||||||
|
description=(
|
||||||
|
"All conditional model inputs that need to be prepared with guider. "
|
||||||
|
"It should contain prompt_embeds/negative_prompt_embeds. "
|
||||||
|
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||||
|
) -> PipelineState:
|
||||||
|
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||||
|
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||||
|
guider_input_fields = {
|
||||||
|
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||||
|
}
|
||||||
|
transformer_dtype = components.transformer.dtype
|
||||||
|
|
||||||
|
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||||
|
|
||||||
|
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||||
|
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||||
|
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||||
|
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||||
|
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||||
|
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||||
|
|
||||||
|
# run the denoiser for each guidance batch
|
||||||
|
for guider_state_batch in guider_state:
|
||||||
|
components.guider.prepare_models(components.transformer)
|
||||||
|
cond_kwargs = guider_state_batch.as_dict()
|
||||||
|
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||||
|
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||||
|
guider_state_batch.noise_pred = components.transformer(
|
||||||
|
hidden_states=block_state.latents.to(transformer_dtype),
|
||||||
|
timestep=t.flatten(),
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
attention_kwargs=block_state.attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
components.guider.cleanup_models(components.transformer)
|
||||||
|
|
||||||
|
# Perform guidance
|
||||||
|
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", UniPCMultistepScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"step within the denoising loop that update the latents. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam("generator"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||||
|
# Perform scheduler step using the predicted output
|
||||||
|
latents_dtype = block_state.latents.dtype
|
||||||
|
block_state.latents = components.scheduler.step(
|
||||||
|
block_state.noise_pred.float(),
|
||||||
|
t,
|
||||||
|
block_state.latents.float(),
|
||||||
|
**block_state.scheduler_step_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if block_state.latents.dtype != latents_dtype:
|
||||||
|
block_state.latents = block_state.latents.to(latents_dtype)
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Pipeline block that iteratively denoise the latents over `timesteps`. "
|
||||||
|
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 5.0}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
ComponentSpec("scheduler", UniPCMultistepScheduler),
|
||||||
|
ComponentSpec("transformer", WanTransformer3DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"timesteps",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.num_warmup_steps = max(
|
||||||
|
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(block_state.timesteps):
|
||||||
|
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||||
|
if i == len(block_state.timesteps) - 1 or (
|
||||||
|
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||||
|
):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class WanDenoiseStep(WanDenoiseLoopWrapper):
|
||||||
|
block_classes = [
|
||||||
|
WanLoopDenoiser,
|
||||||
|
WanLoopAfterDenoiser,
|
||||||
|
]
|
||||||
|
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Denoise step that iteratively denoise the latents. \n"
|
||||||
|
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||||
|
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||||
|
" - `WanLoopDenoiser`\n"
|
||||||
|
" - `WanLoopAfterDenoiser`\n"
|
||||||
|
"This block supports both text2vid tasks."
|
||||||
|
)
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import html
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...guiders import ClassifierFreeGuidance
|
||||||
|
from ...utils import is_ftfy_available, logging
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import WanModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_ftfy_available():
|
||||||
|
import ftfy
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def basic_clean(text):
|
||||||
|
text = ftfy.fix_text(text)
|
||||||
|
text = html.unescape(html.unescape(text))
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_clean(text):
|
||||||
|
text = re.sub(r"\s+", " ", text)
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_clean(text):
|
||||||
|
text = whitespace_clean(basic_clean(text))
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Text Encoder step that generate text_embeddings to guide the video generation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("text_encoder", UMT5EncoderModel),
|
||||||
|
ComponentSpec("tokenizer", AutoTokenizer),
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 5.0}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_configs(self) -> List[ConfigSpec]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("prompt"),
|
||||||
|
InputParam("negative_prompt"),
|
||||||
|
InputParam("attention_kwargs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
kwargs_type="guider_input_fields",
|
||||||
|
description="text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"negative_prompt_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
kwargs_type="guider_input_fields",
|
||||||
|
description="negative text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(block_state):
|
||||||
|
if block_state.prompt is not None and (
|
||||||
|
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
|
||||||
|
):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_t5_prompt_embeds(
|
||||||
|
components,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
max_sequence_length: int,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
dtype = components.text_encoder.dtype
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
prompt = [prompt_clean(u) for u in prompt]
|
||||||
|
|
||||||
|
text_inputs = components.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||||
|
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||||
|
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||||
|
prompt_embeds = torch.stack(
|
||||||
|
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode_prompt(
|
||||||
|
components,
|
||||||
|
prompt: str,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
num_videos_per_prompt: int = 1,
|
||||||
|
prepare_unconditional_embeds: bool = True,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Encodes the prompt into text encoder hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
prompt to be encoded
|
||||||
|
device: (`torch.device`):
|
||||||
|
torch device
|
||||||
|
num_videos_per_prompt (`int`):
|
||||||
|
number of videos that should be generated per prompt
|
||||||
|
prepare_unconditional_embeds (`bool`):
|
||||||
|
whether to use prepare unconditional embeddings or not
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
max_sequence_length (`int`, defaults to `512`):
|
||||||
|
The maximum number of text tokens to be used for the generation process.
|
||||||
|
"""
|
||||||
|
device = device or components._execution_device
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
|
||||||
|
|
||||||
|
if prepare_unconditional_embeds and negative_prompt_embeds is None:
|
||||||
|
negative_prompt = negative_prompt or ""
|
||||||
|
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||||
|
|
||||||
|
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||||
|
raise TypeError(
|
||||||
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
|
f" {type(prompt)}."
|
||||||
|
)
|
||||||
|
elif batch_size != len(negative_prompt):
|
||||||
|
raise ValueError(
|
||||||
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||||
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||||
|
" the batch size of `prompt`."
|
||||||
|
)
|
||||||
|
|
||||||
|
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
|
||||||
|
components, negative_prompt, max_sequence_length, device
|
||||||
|
)
|
||||||
|
|
||||||
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
# Get inputs and intermediates
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(block_state)
|
||||||
|
|
||||||
|
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
# Encode input prompt
|
||||||
|
(
|
||||||
|
block_state.prompt_embeds,
|
||||||
|
block_state.negative_prompt_embeds,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
components,
|
||||||
|
block_state.prompt,
|
||||||
|
block_state.device,
|
||||||
|
1,
|
||||||
|
block_state.prepare_unconditional_embeds,
|
||||||
|
block_state.negative_prompt,
|
||||||
|
prompt_embeds=None,
|
||||||
|
negative_prompt_embeds=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add outputs
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||||
|
from ..modular_pipeline_utils import InsertableDict
|
||||||
|
from .before_denoise import (
|
||||||
|
WanInputStep,
|
||||||
|
WanPrepareLatentsStep,
|
||||||
|
WanSetTimestepsStep,
|
||||||
|
)
|
||||||
|
from .decoders import WanDecodeStep
|
||||||
|
from .denoise import WanDenoiseStep
|
||||||
|
from .encoders import WanTextEncoderStep
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# before_denoise: text2vid
|
||||||
|
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||||
|
block_classes = [
|
||||||
|
WanInputStep,
|
||||||
|
WanSetTimestepsStep,
|
||||||
|
WanPrepareLatentsStep,
|
||||||
|
]
|
||||||
|
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||||
|
+ "This is a sequential pipeline blocks:\n"
|
||||||
|
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
|
||||||
|
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||||
|
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# before_denoise: all task (text2vid,)
|
||||||
|
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [
|
||||||
|
WanBeforeDenoiseStep,
|
||||||
|
]
|
||||||
|
block_names = ["text2vid"]
|
||||||
|
block_trigger_inputs = [None]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||||
|
+ "This is an auto pipeline block that works for text2vid.\n"
|
||||||
|
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# denoise: text2vid
|
||||||
|
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [
|
||||||
|
WanDenoiseStep,
|
||||||
|
]
|
||||||
|
block_names = ["denoise"]
|
||||||
|
block_trigger_inputs = [None]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Denoise step that iteratively denoise the latents. "
|
||||||
|
"This is a auto pipeline block that works for text2vid tasks.."
|
||||||
|
" - `WanDenoiseStep` (denoise) for text2vid tasks."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# decode: all task (text2img, img2img, inpainting)
|
||||||
|
class WanAutoDecodeStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [WanDecodeStep]
|
||||||
|
block_names = ["non-inpaint"]
|
||||||
|
block_trigger_inputs = [None]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
|
||||||
|
|
||||||
|
|
||||||
|
# text2vid
|
||||||
|
class WanAutoBlocks(SequentialPipelineBlocks):
|
||||||
|
block_classes = [
|
||||||
|
WanTextEncoderStep,
|
||||||
|
WanAutoBeforeDenoiseStep,
|
||||||
|
WanAutoDenoiseStep,
|
||||||
|
WanAutoDecodeStep,
|
||||||
|
]
|
||||||
|
block_names = [
|
||||||
|
"text_encoder",
|
||||||
|
"before_denoise",
|
||||||
|
"denoise",
|
||||||
|
"decoder",
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Auto Modular pipeline for text-to-video using Wan.\n"
|
||||||
|
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TEXT2VIDEO_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", WanTextEncoderStep),
|
||||||
|
("input", WanInputStep),
|
||||||
|
("set_timesteps", WanSetTimestepsStep),
|
||||||
|
("prepare_latents", WanPrepareLatentsStep),
|
||||||
|
("denoise", WanDenoiseStep),
|
||||||
|
("decode", WanDecodeStep),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
AUTO_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", WanTextEncoderStep),
|
||||||
|
("before_denoise", WanAutoBeforeDenoiseStep),
|
||||||
|
("denoise", WanAutoDenoiseStep),
|
||||||
|
("decode", WanAutoDecodeStep),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ALL_BLOCKS = {
|
||||||
|
"text2video": TEXT2VIDEO_BLOCKS,
|
||||||
|
"auto": AUTO_BLOCKS,
|
||||||
|
}
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ...loaders import WanLoraLoaderMixin
|
||||||
|
from ...pipelines.pipeline_utils import StableDiffusionMixin
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import ModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class WanModularPipeline(
|
||||||
|
ModularPipeline,
|
||||||
|
StableDiffusionMixin,
|
||||||
|
WanLoraLoaderMixin,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
A ModularPipeline for Wan.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This is an experimental feature and is likely to change in the future.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_height(self):
|
||||||
|
return self.default_sample_height * self.vae_scale_factor_spatial
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_width(self):
|
||||||
|
return self.default_sample_width * self.vae_scale_factor_spatial
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_num_frames(self):
|
||||||
|
return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_sample_height(self):
|
||||||
|
return 60
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_sample_width(self):
|
||||||
|
return 104
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_sample_num_frames(self):
|
||||||
|
return 21
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vae_scale_factor_spatial(self):
|
||||||
|
vae_scale_factor = 8
|
||||||
|
if hasattr(self, "vae") and self.vae is not None:
|
||||||
|
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
|
||||||
|
return vae_scale_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vae_scale_factor_temporal(self):
|
||||||
|
vae_scale_factor = 4
|
||||||
|
if hasattr(self, "vae") and self.vae is not None:
|
||||||
|
vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
|
||||||
|
return vae_scale_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_channels_transformer(self):
|
||||||
|
num_channels_transformer = 16
|
||||||
|
if hasattr(self, "transformer") and self.transformer is not None:
|
||||||
|
num_channels_transformer = self.transformer.config.in_channels
|
||||||
|
return num_channels_transformer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_channels_latents(self):
|
||||||
|
num_channels_latents = 16
|
||||||
|
if hasattr(self, "vae") and self.vae is not None:
|
||||||
|
num_channels_latents = self.vae.config.z_dim
|
||||||
|
return num_channels_latents
|
||||||
@@ -663,11 +663,11 @@ class ChromaPipeline(
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
|
||||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||||
usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
|||||||
@@ -725,11 +725,11 @@ class ChromaImg2ImgPipeline(
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
|
||||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||||
usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
strength (`float, *optional*, defaults to 0.9):
|
strength (`float, *optional*, defaults to 0.9):
|
||||||
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
|
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
|
||||||
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
|
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
|
||||||
|
|||||||
@@ -674,7 +674,8 @@ class FluxPipeline(
|
|||||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||||
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
|
||||||
|
`negative_prompt` is provided.
|
||||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
@@ -687,11 +688,11 @@ class FluxPipeline(
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
|||||||
@@ -661,11 +661,11 @@ class FluxControlPipeline(
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with prompt at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
|||||||
@@ -795,11 +795,11 @@ class FluxKontextPipeline(
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with prompt at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
|||||||
@@ -989,7 +989,8 @@ class FluxKontextInpaintPipeline(
|
|||||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||||
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
|
||||||
|
`negative_prompt` is provided.
|
||||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
@@ -1015,11 +1016,11 @@ class FluxKontextInpaintPipeline(
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
|||||||
@@ -763,11 +763,11 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
negative_prompt (`str` or `List[str]`, *optional*):
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||||
|
|||||||
@@ -529,15 +529,14 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
|||||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
will be used.
|
will be used.
|
||||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||||
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
|
||||||
|
`negative_prompt` is provided.
|
||||||
guidance_scale (`float`, defaults to `6.0`):
|
guidance_scale (`float`, defaults to `6.0`):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality. Note that the only available
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
|
|
||||||
conditional latent is not applied.
|
|
||||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
|||||||
@@ -643,11 +643,11 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
|||||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||||
passed will be used. Must be in descending order.
|
passed will be used. Must be in descending order.
|
||||||
guidance_scale (`float`, *optional*, defaults to 4.5):
|
guidance_scale (`float`, *optional*, defaults to 4.5):
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion
|
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
|
||||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||||
the text `prompt`, usually at the expense of lower image quality.
|
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
The number of images to generate per prompt.
|
The number of images to generate per prompt.
|
||||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||||
|
|||||||
@@ -32,6 +32,36 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers"])
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class WanAutoBlocks(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class WanModularPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class AllegroPipeline(metaclass=DummyObject):
|
class AllegroPipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ from diffusers.utils.testing_utils import (
|
|||||||
require_torch_2,
|
require_torch_2,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_accelerator_with_training,
|
require_torch_accelerator_with_training,
|
||||||
require_torch_gpu,
|
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_torch_version_greater,
|
require_torch_version_greater,
|
||||||
run_test_in_subprocess,
|
run_test_in_subprocess,
|
||||||
@@ -1829,8 +1828,8 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
assert msg_substring in str(err_ctx.exception)
|
assert msg_substring in str(err_ctx.exception)
|
||||||
|
|
||||||
@parameterized.expand([0, "cuda", torch.device("cuda")])
|
@parameterized.expand([0, torch_device, torch.device(torch_device)])
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_passing_non_dict_device_map_works(self, device_map):
|
def test_passing_non_dict_device_map_works(self, device_map):
|
||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
model = self.model_class(**init_dict).eval()
|
model = self.model_class(**init_dict).eval()
|
||||||
@@ -1839,8 +1838,8 @@ class ModelTesterMixin:
|
|||||||
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
|
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
|
||||||
_ = loaded_model(**inputs_dict)
|
_ = loaded_model(**inputs_dict)
|
||||||
|
|
||||||
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
|
@parameterized.expand([("", torch_device), ("", torch.device(torch_device))])
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_passing_dict_device_map_works(self, name, device):
|
def test_passing_dict_device_map_works(self, name, device):
|
||||||
# There are other valid dict-based `device_map` values too. It's best to refer to
|
# There are other valid dict-based `device_map` values too. It's best to refer to
|
||||||
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
|
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
|
||||||
@@ -1945,10 +1944,11 @@ class ModelPushToHubTester(unittest.TestCase):
|
|||||||
delete_repo(self.repo_id, token=TOKEN)
|
delete_repo(self.repo_id, token=TOKEN)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
@require_torch_2
|
@require_torch_2
|
||||||
@is_torch_compile
|
@is_torch_compile
|
||||||
@slow
|
@slow
|
||||||
|
@require_torch_version_greater("2.7.1")
|
||||||
class TorchCompileTesterMixin:
|
class TorchCompileTesterMixin:
|
||||||
different_shapes_for_compilation = None
|
different_shapes_for_compilation = None
|
||||||
|
|
||||||
@@ -2013,7 +2013,7 @@ class TorchCompileTesterMixin:
|
|||||||
model.eval()
|
model.eval()
|
||||||
# TODO: Can test for other group offloading kwargs later if needed.
|
# TODO: Can test for other group offloading kwargs later if needed.
|
||||||
group_offload_kwargs = {
|
group_offload_kwargs = {
|
||||||
"onload_device": "cuda",
|
"onload_device": torch_device,
|
||||||
"offload_device": "cpu",
|
"offload_device": "cpu",
|
||||||
"offload_type": "block_level",
|
"offload_type": "block_level",
|
||||||
"num_blocks_per_group": 1,
|
"num_blocks_per_group": 1,
|
||||||
@@ -2047,6 +2047,7 @@ class TorchCompileTesterMixin:
|
|||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_peft_backend
|
@require_peft_backend
|
||||||
@require_peft_version_greater("0.14.0")
|
@require_peft_version_greater("0.14.0")
|
||||||
|
@require_torch_version_greater("2.7.1")
|
||||||
@is_torch_compile
|
@is_torch_compile
|
||||||
class LoraHotSwappingForModelTesterMixin:
|
class LoraHotSwappingForModelTesterMixin:
|
||||||
"""Test that hotswapping does not result in recompilation on the model directly.
|
"""Test that hotswapping does not result in recompilation on the model directly.
|
||||||
|
|||||||
@@ -358,7 +358,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
|||||||
model_class = UNet2DConditionModel
|
model_class = UNet2DConditionModel
|
||||||
main_input_name = "sample"
|
main_input_name = "sample"
|
||||||
# We override the items here because the unet under consideration is small.
|
# We override the items here because the unet under consideration is small.
|
||||||
model_split_percents = [0.5, 0.3, 0.4]
|
model_split_percents = [0.5, 0.34, 0.4]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_input(self):
|
def dummy_input(self):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
import gc
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer, T5EncoderModel
|
from transformers import AutoTokenizer, T5EncoderModel
|
||||||
|
|
||||||
@@ -29,9 +28,7 @@ from diffusers.utils.testing_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||||
from ..test_pipelines_common import (
|
from ..test_pipelines_common import PipelineTesterMixin
|
||||||
PipelineTesterMixin,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
enable_full_determinism()
|
enable_full_determinism()
|
||||||
@@ -127,11 +124,15 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
inputs = self.get_dummy_inputs(device)
|
inputs = self.get_dummy_inputs(device)
|
||||||
video = pipe(**inputs).frames
|
video = pipe(**inputs).frames
|
||||||
generated_video = video[0]
|
generated_video = video[0]
|
||||||
|
|
||||||
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
||||||
expected_video = torch.randn(9, 3, 16, 16)
|
|
||||||
max_diff = np.abs(generated_video - expected_video).max()
|
# fmt: off
|
||||||
self.assertLessEqual(max_diff, 1e10)
|
expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
generated_slice = generated_video.flatten()
|
||||||
|
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||||
|
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||||
|
|
||||||
@unittest.skip("Test not supported")
|
@unittest.skip("Test not supported")
|
||||||
def test_attention_slicing_forward_pass(self):
|
def test_attention_slicing_forward_pass(self):
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -147,11 +146,15 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
inputs = self.get_dummy_inputs(device)
|
inputs = self.get_dummy_inputs(device)
|
||||||
video = pipe(**inputs).frames
|
video = pipe(**inputs).frames
|
||||||
generated_video = video[0]
|
generated_video = video[0]
|
||||||
|
|
||||||
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
||||||
expected_video = torch.randn(9, 3, 16, 16)
|
|
||||||
max_diff = np.abs(generated_video - expected_video).max()
|
# fmt: off
|
||||||
self.assertLessEqual(max_diff, 1e10)
|
expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
generated_slice = generated_video.flatten()
|
||||||
|
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||||
|
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||||
|
|
||||||
@unittest.skip("Test not supported")
|
@unittest.skip("Test not supported")
|
||||||
def test_attention_slicing_forward_pass(self):
|
def test_attention_slicing_forward_pass(self):
|
||||||
@@ -162,7 +165,25 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
|
class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
pipeline_class = WanImageToVideoPipeline
|
||||||
|
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
|
||||||
|
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||||
|
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
required_optional_params = frozenset(
|
||||||
|
[
|
||||||
|
"num_inference_steps",
|
||||||
|
"generator",
|
||||||
|
"latents",
|
||||||
|
"return_dict",
|
||||||
|
"callback_on_step_end",
|
||||||
|
"callback_on_step_end_tensor_inputs",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
test_xformers_attention = False
|
||||||
|
supports_dduf = False
|
||||||
|
|
||||||
def get_dummy_components(self):
|
def get_dummy_components(self):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
vae = AutoencoderKLWan(
|
vae = AutoencoderKLWan(
|
||||||
@@ -247,3 +268,32 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
|
|||||||
"output_type": "pt",
|
"output_type": "pt",
|
||||||
}
|
}
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
def test_inference(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
video = pipe(**inputs).frames
|
||||||
|
generated_video = video[0]
|
||||||
|
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
generated_slice = generated_video.flatten()
|
||||||
|
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||||
|
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||||
|
|
||||||
|
@unittest.skip("Test not supported")
|
||||||
|
def test_attention_slicing_forward_pass(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
|
||||||
|
def test_inference_batch_single_identical(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoTokenizer, T5EncoderModel
|
from transformers import AutoTokenizer, T5EncoderModel
|
||||||
@@ -123,11 +122,15 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
inputs = self.get_dummy_inputs(device)
|
inputs = self.get_dummy_inputs(device)
|
||||||
video = pipe(**inputs).frames
|
video = pipe(**inputs).frames
|
||||||
generated_video = video[0]
|
generated_video = video[0]
|
||||||
|
|
||||||
self.assertEqual(generated_video.shape, (17, 3, 16, 16))
|
self.assertEqual(generated_video.shape, (17, 3, 16, 16))
|
||||||
expected_video = torch.randn(17, 3, 16, 16)
|
|
||||||
max_diff = np.abs(generated_video - expected_video).max()
|
# fmt: off
|
||||||
self.assertLessEqual(max_diff, 1e10)
|
expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
|
||||||
|
# fmt:on
|
||||||
|
|
||||||
|
generated_slice = generated_video.flatten()
|
||||||
|
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||||
|
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||||
|
|
||||||
@unittest.skip("Test not supported")
|
@unittest.skip("Test not supported")
|
||||||
def test_attention_slicing_forward_pass(self):
|
def test_attention_slicing_forward_pass(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user