Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bffa3a9754 | |||
| 1c558712e8 | |||
| 1f026ad14e | |||
| 0c7589293b | |||
| ff263947ad | |||
| 66e6a0215f | |||
| 5a47442f92 | |||
| 8f6328c4a4 | |||
| 8d45f219d0 | |||
| 0fd58c7706 | |||
| 35d703310c | |||
| b455dc94a2 | |||
| 04f9d2bf3d | |||
| bc8fd864eb |
@@ -76,6 +76,7 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -127,6 +128,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -178,6 +180,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
@@ -329,6 +329,8 @@
|
||||
title: BriaTransformer2DModel
|
||||
- local: api/models/chroma_transformer
|
||||
title: ChromaTransformer2DModel
|
||||
- local: api/models/chronoedit_transformer_3d
|
||||
title: ChronoEditTransformer3DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
@@ -628,6 +630,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/chronoedit
|
||||
title: ChronoEdit
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# ChronoEditTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
|
||||
|
||||
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import ChronoEditTransformer3DModel
|
||||
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## ChronoEditTransformer3DModel
|
||||
|
||||
[[autodoc]] ChronoEditTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,156 @@
|
||||
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# ChronoEdit
|
||||
|
||||
[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
|
||||
|
||||
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
|
||||
|
||||
*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
|
||||
|
||||
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
|
||||
|
||||
|
||||
### Image Editing
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
Optionally, enable **temporal reasoning** for improved physical consistency:
|
||||
```py
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=29,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
enable_temporal_reasoning=True,
|
||||
num_temporal_reasoning_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
### Inference with 8-Step Distillation Lora
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
|
||||
pipe.load_lora_weights(lora_path)
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=1.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
## ChronoEditPipeline
|
||||
|
||||
[[autodoc]] ChronoEditPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChronoEditPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
|
||||
@@ -45,7 +45,7 @@ def check_size(image, height, width):
|
||||
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
||||
|
||||
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
|
||||
inner_image = inner_image.convert("RGBA")
|
||||
image = image.convert("RGB")
|
||||
|
||||
|
||||
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
|
||||
from diffusers import (
|
||||
SanaControlNetModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import AutoencoderKL, SD3Transformer2DModel
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers import (
|
||||
StableAudioPipeline,
|
||||
StableAudioProjectionModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -202,6 +202,7 @@ else:
|
||||
"BriaTransformer2DModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
"ChronoEditTransformer3DModel",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"CogView4Transformer2DModel",
|
||||
@@ -406,6 +407,7 @@ else:
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanModularPipeline",
|
||||
]
|
||||
@@ -436,6 +438,7 @@ else:
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaPipeline",
|
||||
"ChronoEditPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXFunControlPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -909,6 +912,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaTransformer2DModel,
|
||||
CacheMixin,
|
||||
ChromaTransformer2DModel,
|
||||
ChronoEditTransformer3DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -1087,6 +1091,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
@@ -1113,6 +1118,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaPipeline,
|
||||
ChronoEditPipeline,
|
||||
CLIPImageProjection,
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -88,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -141,6 +141,16 @@ class AutoGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -99,6 +99,16 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -85,6 +85,16 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -226,6 +226,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -166,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
@@ -234,6 +239,51 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch_from_block_state(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
"""
|
||||
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
||||
length 2 is provided, the first element must be the conditional data identifier and the second element
|
||||
must be the unconditional data identifier or None.
|
||||
data (`BlockState`):
|
||||
The input data to be prepared.
|
||||
tuple_index (`int`):
|
||||
The index to use when accessing input fields that are tuples.
|
||||
|
||||
Returns:
|
||||
`BlockState`: The prepared batch of data.
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
|
||||
@@ -187,6 +187,26 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -183,6 +183,26 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -172,6 +172,26 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -74,6 +74,16 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
@@ -179,6 +180,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
ChronoEditTransformer3DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
|
||||
@@ -44,11 +44,16 @@ class ContextParallelConfig:
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
|
||||
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
|
||||
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
|
||||
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
|
||||
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
|
||||
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
|
||||
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
|
||||
good interconnect bandwidth.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
@@ -79,29 +84,46 @@ class ContextParallelConfig:
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
if self.ring_degree == 1 and self.ulysses_degree == 1:
|
||||
raise ValueError(
|
||||
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> Tuple[int, int]:
|
||||
return (self.ring_degree, self.ulysses_degree)
|
||||
|
||||
@property
|
||||
def mesh_dim_names(self) -> Tuple[str, str]:
|
||||
"""Dimension names for the device mesh."""
|
||||
return ("ring", "ulysses")
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
|
||||
if self.ulysses_degree * self.ring_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -119,7 +141,7 @@ class ParallelConfig:
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
@@ -127,14 +149,14 @@ class ParallelConfig:
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._cp_mesh = cp_mesh
|
||||
self._mesh = mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
self.context_parallel_config.setup(rank, world_size, device, mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
|
||||
_backends = {}
|
||||
_constraints = {}
|
||||
_supported_arg_names = {}
|
||||
_supports_context_parallel = {}
|
||||
_supports_context_parallel = set()
|
||||
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
|
||||
_checks_enabled = DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
@@ -237,7 +237,9 @@ class _AttentionBackendRegistry:
|
||||
cls._backends[backend] = func
|
||||
cls._constraints[backend] = constraints or []
|
||||
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
|
||||
cls._supports_context_parallel[backend] = supports_context_parallel
|
||||
if supports_context_parallel:
|
||||
cls._supports_context_parallel.add(backend.value)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -251,15 +253,12 @@ class _AttentionBackendRegistry:
|
||||
return list(cls._backends.keys())
|
||||
|
||||
@classmethod
|
||||
def _is_context_parallel_enabled(
|
||||
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
|
||||
def _is_context_parallel_available(
|
||||
cls,
|
||||
backend: AttentionBackendName,
|
||||
) -> bool:
|
||||
supports_context_parallel = backend in cls._supports_context_parallel
|
||||
is_degree_greater_than_1 = parallel_config is not None and (
|
||||
parallel_config.context_parallel_config.ring_degree > 1
|
||||
or parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
)
|
||||
return supports_context_parallel and is_degree_greater_than_1
|
||||
supports_context_parallel = backend.value in cls._supports_context_parallel
|
||||
return supports_context_parallel
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
|
||||
backend_name = AttentionBackendName(backend)
|
||||
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
|
||||
|
||||
if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
|
||||
backend_name, parallel_config
|
||||
):
|
||||
raise ValueError(
|
||||
f"Backend {backend_name} either does not support context parallelism or context parallelism "
|
||||
f"was enabled with a world size of 1."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"query": query,
|
||||
"key": key,
|
||||
|
||||
@@ -102,7 +102,7 @@ def get_block(
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
qkv_mutliscales: Tuple[int, ...] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
|
||||
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
|
||||
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
latent_channels: int = 16,
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
|
||||
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
|
||||
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
|
||||
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 15,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
|
||||
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
latent_channels: int = 12,
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
act_fn: str = "silu",
|
||||
|
||||
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
|
||||
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
|
||||
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
|
||||
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
time_embedding_mix: float = 1.0,
|
||||
learn_time_embedding: bool = False,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
cross_attention_dim: int = 1024,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
time_embedding_mix: int = 1.0,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
||||
@@ -529,14 +529,19 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
# unet configs
|
||||
sample_size: Optional[int] = 96,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
@@ -550,10 +555,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# additional controlnet configs
|
||||
time_embedding_mix: float = 1.0,
|
||||
ctrl_conditioning_channels: int = 3,
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_channel_order: str = "rgb",
|
||||
ctrl_learn_time_embedding: bool = False,
|
||||
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
):
|
||||
|
||||
@@ -1484,59 +1484,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning(
|
||||
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
|
||||
)
|
||||
|
||||
if not torch.distributed.is_available() and not torch.distributed.is_initialized():
|
||||
raise RuntimeError(
|
||||
"torch.distributed must be available and initialized before calling `enable_parallelism`."
|
||||
)
|
||||
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
if isinstance(config, ContextParallelConfig):
|
||||
config = ParallelConfig(context_parallel_config=config)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
device = torch.device(device_type, rank % device_module.device_count())
|
||||
|
||||
cp_mesh = None
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
|
||||
processor = module.processor
|
||||
if processor is None or not hasattr(processor, "_attention_backend"):
|
||||
continue
|
||||
|
||||
attention_backend = processor._attention_backend
|
||||
if attention_backend is None:
|
||||
attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
else:
|
||||
attention_backend = AttentionBackendName(attention_backend)
|
||||
|
||||
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
|
||||
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
|
||||
raise ValueError(
|
||||
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
|
||||
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
|
||||
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
|
||||
f"calling `enable_parallelism()`."
|
||||
)
|
||||
|
||||
# All modules use the same attention processor and backend. We don't need to
|
||||
# iterate over all modules after checking the first processor
|
||||
break
|
||||
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
|
||||
mesh_dim_names=("ring", "ulysses"),
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
)
|
||||
|
||||
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
|
||||
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
config.setup(rank, world_size, device, mesh=mesh)
|
||||
self._parallel_config = config
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
@@ -1545,6 +1557,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
continue
|
||||
processor._parallel_config = config
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -20,6 +20,7 @@ if is_torch_available():
|
||||
from .transformer_bria import BriaTransformer2DModel
|
||||
from .transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from .transformer_chroma import ChromaTransformer2DModel
|
||||
from .transformer_chronoedit import ChronoEditTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
|
||||
@@ -0,0 +1,735 @@
|
||||
# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
|
||||
def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
|
||||
# encoder_hidden_states is only passed for cross-attention
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
# In cross-attention layers, we can only fuse the KV projections into a single linear
|
||||
query = attn.to_q(hidden_states)
|
||||
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
return query, key, value
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
|
||||
def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
|
||||
if attn.fused_projections:
|
||||
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
|
||||
else:
|
||||
key_img = attn.add_k_proj(encoder_hidden_states_img)
|
||||
value_img = attn.add_v_proj(encoder_hidden_states_img)
|
||||
return key_img, value_img
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "WanAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states_img = None
|
||||
if attn.add_k_proj is not None:
|
||||
# 512 is the context length of the text encoder, hardcoded for now
|
||||
image_context_length = encoder_hidden_states.shape[1] - 512
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
|
||||
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
|
||||
|
||||
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x1 * cos - x2 * sin
|
||||
out[..., 1::2] = x1 * sin + x2 * cos
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
query = apply_rotary_emb(query, *rotary_emb)
|
||||
key = apply_rotary_emb(key, *rotary_emb)
|
||||
|
||||
# I2V task
|
||||
hidden_states_img = None
|
||||
if encoder_hidden_states_img is not None:
|
||||
key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
|
||||
key_img = attn.norm_added_k(key_img)
|
||||
|
||||
key_img = key_img.unflatten(2, (attn.heads, -1))
|
||||
value_img = value_img.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states_img = dispatch_attention_fn(
|
||||
query,
|
||||
key_img,
|
||||
value_img,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
if hidden_states_img is not None:
|
||||
hidden_states = hidden_states + hidden_states_img
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor2_0
|
||||
class WanAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = (
|
||||
"The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
|
||||
"Please use WanAttnProcessor instead. "
|
||||
)
|
||||
deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
|
||||
return WanAttnProcessor(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanAttention
|
||||
class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = WanAttnProcessor
|
||||
_available_processors = [WanAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
eps: float = 1e-5,
|
||||
dropout: float = 0.0,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
cross_attention_dim_head: Optional[int] = None,
|
||||
processor=None,
|
||||
is_cross_attention=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.cross_attention_dim_head = cross_attention_dim_head
|
||||
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
|
||||
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_out = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.inner_dim, dim, bias=True),
|
||||
torch.nn.Dropout(dropout),
|
||||
]
|
||||
)
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.add_k_proj = self.add_v_proj = None
|
||||
if added_kv_proj_dim is not None:
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def fuse_projections(self):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if self.cross_attention_dim_head is None:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_qkv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_qkv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
else:
|
||||
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_kv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_kv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
|
||||
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_added_kv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
|
||||
self.fused_projections = True
|
||||
|
||||
@torch.no_grad()
|
||||
def unfuse_projections(self):
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if hasattr(self, "to_qkv"):
|
||||
delattr(self, "to_qkv")
|
||||
if hasattr(self, "to_kv"):
|
||||
delattr(self, "to_kv")
|
||||
if hasattr(self, "to_added_kv"):
|
||||
delattr(self, "to_added_kv")
|
||||
|
||||
self.fused_projections = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding
|
||||
class WanImageEmbedding(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = FP32LayerNorm(in_features)
|
||||
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
|
||||
self.norm2 = FP32LayerNorm(out_features)
|
||||
if pos_embed_seq_len is not None:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
|
||||
if self.pos_embed is not None:
|
||||
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
|
||||
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
|
||||
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
|
||||
|
||||
hidden_states = self.norm1(encoder_hidden_states_image)
|
||||
hidden_states = self.ff(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_freq_dim: int,
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
image_embed_dim: Optional[int] = None,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
self.image_embedder = None
|
||||
if image_embed_dim is not None:
|
||||
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
timestep_seq_len: Optional[int] = None,
|
||||
):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
if timestep_seq_len is not None:
|
||||
timestep = timestep.unflatten(0, (-1, timestep_seq_len))
|
||||
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
|
||||
|
||||
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
|
||||
|
||||
|
||||
class ChronoEditRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
temporal_skip_len: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.temporal_skip_len = temporal_skip_len
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq_cos, freq_sin = get_1d_rotary_pos_embed(
|
||||
dim,
|
||||
max_seq_len,
|
||||
theta,
|
||||
use_real=True,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
freqs_cos.append(freq_cos)
|
||||
freqs_sin.append(freq_sin)
|
||||
|
||||
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
|
||||
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
if num_frames == 2:
|
||||
freqs_cos_f = freqs_cos[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
else:
|
||||
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
if num_frames == 2:
|
||||
freqs_sin_f = freqs_sin[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
else:
|
||||
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock
|
||||
class WanTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
eps=eps,
|
||||
cross_attention_dim_head=None,
|
||||
processor=WanAttnProcessor(),
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.attn2 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
eps=eps,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
cross_attention_dim_head=dim // num_heads,
|
||||
processor=WanAttnProcessor(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if temb.ndim == 4:
|
||||
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table.unsqueeze(0) + temb.float()
|
||||
).chunk(6, dim=2)
|
||||
# batch_size, seq_len, 1, inner_dim
|
||||
shift_msa = shift_msa.squeeze(2)
|
||||
scale_msa = scale_msa.squeeze(2)
|
||||
gate_msa = gate_msa.squeeze(2)
|
||||
c_shift_msa = c_shift_msa.squeeze(2)
|
||||
c_scale_msa = c_scale_msa.squeeze(2)
|
||||
c_gate_msa = c_gate_msa.squeeze(2)
|
||||
else:
|
||||
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
||||
hidden_states
|
||||
)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# modified from diffusers.models.transformers.transformer_wan.WanTransformer3DModel
|
||||
class ChronoEditTransformer3DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
|
||||
):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the ChronoEdit model.
|
||||
|
||||
Args:
|
||||
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
|
||||
num_attention_heads (`int`, defaults to `40`):
|
||||
Fixed length for text embeddings.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_dim (`int`, defaults to `512`):
|
||||
Input dimension for text embeddings.
|
||||
freq_dim (`int`, defaults to `256`):
|
||||
Dimension for sinusoidal time embeddings.
|
||||
ffn_dim (`int`, defaults to `13824`):
|
||||
Intermediate dimension in feed-forward network.
|
||||
num_layers (`int`, defaults to `40`):
|
||||
The number of layers of transformer blocks to use.
|
||||
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
|
||||
Window size for local attention (-1 indicates global attention).
|
||||
cross_attn_norm (`bool`, defaults to `True`):
|
||||
Enable cross-attention normalization.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Enable query/key normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
add_img_emb (`bool`, defaults to `False`):
|
||||
Whether to use img_emb.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
|
||||
_no_split_modules = ["WanTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"rope": {
|
||||
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
|
||||
},
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
text_dim: int = 4096,
|
||||
freq_dim: int = 256,
|
||||
ffn_dim: int = 13824,
|
||||
num_layers: int = 40,
|
||||
cross_attn_norm: bool = True,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
eps: float = 1e-6,
|
||||
image_dim: Optional[int] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
rope_max_seq_len: int = 1024,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
rope_temporal_skip_len: int = 8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = ChronoEditRotaryPosEmbed(
|
||||
attention_head_dim, patch_size, rope_max_seq_len, temporal_skip_len=rope_temporal_skip_len
|
||||
)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Condition embeddings
|
||||
# image_embedding_dim=1280 for I2V model
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
image_embed_dim=image_dim,
|
||||
pos_embed_seq_len=pos_embed_seq_len,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanTransformerBlock(
|
||||
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
# timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
|
||||
if timestep.ndim == 2:
|
||||
ts_seq_len = timestep.shape[1]
|
||||
timestep = timestep.flatten() # batch_size * seq_len
|
||||
else:
|
||||
ts_seq_len = None
|
||||
|
||||
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
|
||||
timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
|
||||
)
|
||||
if ts_seq_len is not None:
|
||||
# batch_size, seq_len, 6, inner_dim
|
||||
timestep_proj = timestep_proj.unflatten(2, (6, -1))
|
||||
else:
|
||||
# batch_size, 6, inner_dim
|
||||
timestep_proj = timestep_proj.unflatten(1, (6, -1))
|
||||
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
|
||||
# 5. Output norm, projection & unpatchify
|
||||
if temb.ndim == 3:
|
||||
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift = shift.squeeze(2)
|
||||
scale = scale.squeeze(2)
|
||||
else:
|
||||
# batch_size, inner_dim
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
# first device rather than the last device, which hidden_states ends up
|
||||
# on.
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
has_image_proj: int = False,
|
||||
image_proj_dim: int = 1152,
|
||||
|
||||
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_2_dim: Optional[int] = None,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (64, 64),
|
||||
rope_axes_dim: Tuple[int, ...] = (64, 64),
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -172,7 +172,6 @@ class SanaLinearAttnProcessor3_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -563,7 +564,7 @@ class WanTransformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: str = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
block_out_channels: Tuple[int, ...] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
|
||||
@@ -177,16 +177,21 @@ class UNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -486,10 +491,10 @@ class UNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
groups: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
layers_per_block: Union[int, Tuple[int]] = 3,
|
||||
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
|
||||
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 4096,
|
||||
encoder_hid_dim: int = 4096,
|
||||
):
|
||||
|
||||
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 8,
|
||||
out_channels: int = 4,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
addition_time_embed_dim: int = 256,
|
||||
projection_class_embeddings_input_dim: int = 768,
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
|
||||
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
|
||||
num_frames: int = 25,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
timestep_ratio_embedding_dim: int = 64,
|
||||
patch_size: int = 1,
|
||||
conditioning_dim: int = 2048,
|
||||
block_out_channels: Tuple[int] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
||||
block_out_channels: Tuple[int, ...] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int, ...] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
|
||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||
1,
|
||||
1,
|
||||
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
kernel_size=3,
|
||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||
self_attn: Union[bool, Tuple[bool]] = True,
|
||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
||||
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
|
||||
switch_level: Optional[Tuple[bool]] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: int = (64,)
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
layers_per_block: int = 1
|
||||
act_fn: str = "silu"
|
||||
latent_channels: int = 4
|
||||
|
||||
@@ -45,7 +45,7 @@ else:
|
||||
"InsertableDict",
|
||||
]
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
@@ -90,7 +90,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import WanAutoBlocks, WanModularPipeline
|
||||
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1441,6 +1441,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
modular_config_dict: Optional[Dict[str, Any]] = None,
|
||||
config_dict: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -1492,23 +1494,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
`_blocks_class_name` in the config dict
|
||||
"""
|
||||
if blocks is None:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs from modular_repo
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -1524,52 +1511,59 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
}
|
||||
# try to load modular_model_index.json
|
||||
try:
|
||||
config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f"modular_model_index.json not found: {e}")
|
||||
config_dict = None
|
||||
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
# all the components in modular_model_index.json are from_pretrained components
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
modular_config_dict, config_dict = self._load_pipeline_config(
|
||||
pretrained_model_name_or_path, **load_config_kwargs
|
||||
)
|
||||
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
# if modular_model_index.json is not found, try to load model_index.json
|
||||
if blocks is None:
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
elif config_dict is not None:
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
else:
|
||||
logger.debug(" loading config from model_index.json")
|
||||
try:
|
||||
from diffusers import DiffusionPipeline
|
||||
blocks_class_name = None
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
config_dict = None
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs based on model_index.json
|
||||
if config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
library, class_name = value
|
||||
component_spec_dict = {
|
||||
"repo": pretrained_model_name_or_path,
|
||||
"subfolder": name,
|
||||
"type_hint": (library, class_name),
|
||||
}
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if modular_config_dict is not None:
|
||||
for name, value in modular_config_dict.items():
|
||||
# all the components in modular_model_index.json are from_pretrained components
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
# if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`)
|
||||
elif config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
library, class_name = value
|
||||
component_spec_dict = {
|
||||
"repo": pretrained_model_name_or_path,
|
||||
"subfolder": name,
|
||||
"type_hint": (library, class_name),
|
||||
}
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
|
||||
@@ -1601,6 +1595,35 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
return self.default_blocks_name
|
||||
|
||||
@classmethod
|
||||
def _load_pipeline_config(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
**load_config_kwargs,
|
||||
):
|
||||
try:
|
||||
# try to load modular_model_index.json
|
||||
modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
return modular_config_dict, None
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" modular_model_index.json not found in the repo: {e}")
|
||||
|
||||
try:
|
||||
logger.debug(" try to load model_index.json")
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
return None, config_dict
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
@@ -1655,42 +1678,33 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"revision": revision,
|
||||
}
|
||||
|
||||
try:
|
||||
# try to load modular_model_index.json
|
||||
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" modular_model_index.json not found in the repo: {e}")
|
||||
config_dict = None
|
||||
modular_config_dict, config_dict = cls._load_pipeline_config(
|
||||
pretrained_model_name_or_path, **load_config_kwargs
|
||||
)
|
||||
|
||||
if config_dict is not None:
|
||||
pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
if modular_config_dict is not None:
|
||||
pipeline_class = _get_pipeline_class(cls, config=modular_config_dict)
|
||||
elif config_dict is not None:
|
||||
from diffusers.pipelines.auto_pipeline import _get_model
|
||||
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
try:
|
||||
logger.debug(" try to load model_index.json")
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.pipelines.auto_pipeline import _get_model
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
if config_dict is not None:
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
|
||||
pipeline_class = cls
|
||||
pretrained_model_name_or_path = None
|
||||
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
|
||||
pipeline_class = cls
|
||||
pretrained_model_name_or_path = None
|
||||
|
||||
pipeline = pipeline_class(
|
||||
blocks=blocks,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
components_manager=components_manager,
|
||||
collection=collection,
|
||||
modular_config_dict=modular_config_dict,
|
||||
config_dict=config_dict,
|
||||
**kwargs,
|
||||
)
|
||||
return pipeline
|
||||
@@ -2134,7 +2148,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
||||
@@ -21,16 +21,14 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
|
||||
_import_structure["encoders"] = ["WanTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"TEXT2VIDEO_BLOCKS",
|
||||
"WanAutoBeforeDenoiseStep",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanAutoDecodeStep",
|
||||
"WanAutoDenoiseStep",
|
||||
"WanAutoImageEncoderStep",
|
||||
"WanAutoVaeImageEncoderStep",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
||||
|
||||
@@ -41,15 +39,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .encoders import WanTextEncoderStep
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
TEXT2VIDEO_BLOCKS,
|
||||
WanAutoBeforeDenoiseStep,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanAutoDecodeStep,
|
||||
WanAutoDenoiseStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
)
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
else:
|
||||
|
||||
@@ -13,10 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import WanTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -34,6 +35,97 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
# configuration of guider is.
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
input_name: str,
|
||||
input_tensor: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_videos_per_prompt: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Repeat tensor elements to match the final batch size.
|
||||
|
||||
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt)
|
||||
by repeating each element along dimension 0.
|
||||
|
||||
The input tensor must have batch size 1 or batch_size. The function will:
|
||||
- If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times
|
||||
- If batch size equals batch_size: repeat each element num_videos_per_prompt times
|
||||
|
||||
Args:
|
||||
input_name (str): Name of the input tensor (used for error messages)
|
||||
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||
batch_size (int): The base batch size (number of prompts)
|
||||
num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt)
|
||||
|
||||
Raises:
|
||||
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||
|
||||
Examples:
|
||||
tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
|
||||
batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
|
||||
[4, 3]
|
||||
|
||||
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
|
||||
tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
|
||||
- shape: [4, 3]
|
||||
"""
|
||||
# make sure input is a tensor
|
||||
if not isinstance(input_tensor, torch.Tensor):
|
||||
raise ValueError(f"`{input_name}` must be a tensor")
|
||||
|
||||
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||
if input_tensor.shape[0] == 1:
|
||||
repeat_by = batch_size * num_videos_per_prompt
|
||||
elif input_tensor.shape[0] == batch_size:
|
||||
repeat_by = num_videos_per_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# expand the tensor to match the batch_size * num_videos_per_prompt
|
||||
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
def calculate_dimension_from_latents(
|
||||
latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Calculate image dimensions from latent tensor dimensions.
|
||||
|
||||
This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by
|
||||
multiplying the latent num_frames/height/width by the VAE scale factor.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
|
||||
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
|
||||
vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension.
|
||||
Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension)
|
||||
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension.
|
||||
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The calculated image dimensions as (height, width)
|
||||
|
||||
Raises:
|
||||
ValueError: If latents tensor doesn't have 4 or 5 dimensions
|
||||
|
||||
"""
|
||||
if latents.ndim != 5:
|
||||
raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
_, _, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
|
||||
num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
|
||||
height = latent_height * vae_scale_factor_spatial
|
||||
width = latent_width * vae_scale_factor_spatial
|
||||
|
||||
return num_frames, height, width
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -94,7 +186,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
class WanTextInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -109,14 +201,15 @@ class WanInputStep(ModularPipelineBlocks):
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
InputParam("num_videos_per_prompt", default=1),
|
||||
ComponentSpec("transformer", WanTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_videos_per_prompt", default=1),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -141,19 +234,7 @@ class WanInputStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
description="Data type of model tensor inputs (determined by `transformer.dtype`)",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -194,6 +275,140 @@ class WanInputStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanAdditionalInputsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["first_frame_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
|
||||
a single string or list of strings. Defaults to ["first_frame_latents"].
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to [].
|
||||
|
||||
Examples:
|
||||
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
|
||||
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_videos_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="num_frames"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate num_frames, height/width from latents
|
||||
num_frames, height, width = calculate_dimension_from_latents(
|
||||
image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial
|
||||
)
|
||||
block_state.num_frames = block_state.num_frames or num_frames
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_videos_per_prompt=block_state.num_videos_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_videos_per_prompt=block_state.num_videos_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@@ -215,26 +430,15 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
InputParam("sigmas"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
device = components._execution_device
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
components.scheduler,
|
||||
block_state.num_inference_steps,
|
||||
block_state.device,
|
||||
device,
|
||||
block_state.timesteps,
|
||||
block_state.sigmas,
|
||||
)
|
||||
@@ -246,10 +450,6 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
@@ -262,11 +462,6 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
InputParam("num_frames", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_videos_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -337,29 +532,106 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32 # Wan latents should be torch.float32 for best quality
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.num_frames = block_state.num_frames or components.default_num_frames
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.latents = self.prepare_latents(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_videos_per_prompt,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.num_frames,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
block_state.latents,
|
||||
batch_size=block_state.batch_size * block_state.num_videos_per_prompt,
|
||||
num_channels_latents=components.num_channels_latents,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
num_frames=block_state.num_frames,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=block_state.generator,
|
||||
latents=block_state.latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked first frame latents and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
|
||||
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
|
||||
block_state.first_last_frame_latents = torch.concat(
|
||||
[mask_lat_size, block_state.first_last_frame_latents], dim=1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
class WanImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -50,12 +50,6 @@ class WanDecodeStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -80,25 +74,20 @@ class WanDecodeStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
vae_dtype = components.vae.dtype
|
||||
|
||||
if not block_state.output_type == "latent":
|
||||
latents = block_state.latents
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
latents = latents.to(vae_dtype)
|
||||
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
|
||||
else:
|
||||
block_state.videos = block_state.latents
|
||||
|
||||
block_state.videos = components.video_processor.postprocess_video(
|
||||
block_state.videos, output_type=block_state.output_type
|
||||
latents = block_state.latents
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
latents = latents.to(vae_dtype)
|
||||
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,16 +27,156 @@ from ..modular_pipeline import (
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = block_state.latents.to(block_state.dtype)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
|
||||
block_state.dtype
|
||||
)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_last_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat(
|
||||
[block_state.latents, block_state.first_last_frame_latents], dim=1
|
||||
).to(block_state.dtype)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
|
||||
):
|
||||
"""Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1.
|
||||
|
||||
Args:
|
||||
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
|
||||
(for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
|
||||
|
||||
- A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
|
||||
"negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
|
||||
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
|
||||
'encoder_hidden_states'.
|
||||
- A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
|
||||
`block_state.image_embeds` for both conditional and unconditional batches.
|
||||
"""
|
||||
if not isinstance(guider_input_fields, dict):
|
||||
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||
self._guider_input_fields = guider_input_fields
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
@@ -59,49 +199,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
inputs = [
|
||||
InputParam("attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs that need to be prepared with guider. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds. "
|
||||
"Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
),
|
||||
),
|
||||
]
|
||||
guider_input_names = []
|
||||
for value in self._guider_input_fields.values():
|
||||
if isinstance(value, tuple):
|
||||
guider_input_names.extend(value)
|
||||
else:
|
||||
guider_input_names.append(value)
|
||||
|
||||
for name in guider_input_names:
|
||||
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
}
|
||||
transformer_dtype = components.transformer.dtype
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
@@ -112,22 +233,26 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {
|
||||
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in cond_kwargs.items()
|
||||
if k in self._guider_input_fields.keys()
|
||||
}
|
||||
|
||||
# Predict the noise residual
|
||||
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latents.to(transformer_dtype),
|
||||
timestep=t.flatten(),
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
hidden_states=block_state.latent_model_input.to(block_state.dtype),
|
||||
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
@@ -137,6 +262,141 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class Wan22LoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
|
||||
):
|
||||
"""Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2.
|
||||
|
||||
Args:
|
||||
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
|
||||
(for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either:
|
||||
|
||||
- A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds",
|
||||
"negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and
|
||||
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
|
||||
`encoder_hidden_states`.
|
||||
- A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward
|
||||
`block_state.image_embeds` for both conditional and unconditional batches.
|
||||
"""
|
||||
if not isinstance(guider_input_fields, dict):
|
||||
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||
self._guider_input_fields = guider_input_fields
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec(
|
||||
"guider_2",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 3.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", WanTransformer3DModel),
|
||||
ComponentSpec("transformer_2", WanTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoise the latents with guidance. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(
|
||||
name="boundary_ratio",
|
||||
default=0.875,
|
||||
description="The boundary ratio to divide the denoising loop into high noise and low noise stages.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
inputs = [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
guider_input_names = []
|
||||
for value in self._guider_input_fields.values():
|
||||
if isinstance(value, tuple):
|
||||
guider_input_names.extend(value)
|
||||
else:
|
||||
guider_input_names.append(value)
|
||||
|
||||
for name in guider_input_names:
|
||||
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps
|
||||
if t >= boundary_timestep:
|
||||
block_state.current_model = components.transformer
|
||||
block_state.guider = components.guider
|
||||
else:
|
||||
block_state.current_model = components.transformer_2
|
||||
block_state.guider = components.guider_2
|
||||
|
||||
block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
block_state.guider.prepare_models(block_state.current_model)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {
|
||||
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in cond_kwargs.items()
|
||||
if k in self._guider_input_fields.keys()
|
||||
}
|
||||
|
||||
# Predict the noise residual
|
||||
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||
guider_state_batch.noise_pred = block_state.current_model(
|
||||
hidden_states=block_state.latent_model_input.to(block_state.dtype),
|
||||
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
block_state.guider.cleanup_models(block_state.current_model)
|
||||
|
||||
# Perform guidance
|
||||
block_state.noise_pred = block_state.guider(guider_state)[0]
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@@ -154,20 +414,6 @@ class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# Perform scheduler step using the predicted output
|
||||
@@ -198,18 +444,11 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
@property
|
||||
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 5.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("scheduler", UniPCMultistepScheduler),
|
||||
ComponentSpec("transformer", WanTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -248,7 +487,12 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
|
||||
class WanDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanLoopDenoiser,
|
||||
WanLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
@@ -259,7 +503,110 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports both text2vid tasks."
|
||||
"This block supports text-to-video tasks for wan2.1."
|
||||
)
|
||||
|
||||
|
||||
class Wan22DenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanLoopBeforeDenoiser,
|
||||
Wan22LoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanLoopBeforeDenoiser`\n"
|
||||
" - `Wan22LoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports text-to-video tasks for Wan2.2."
|
||||
)
|
||||
|
||||
|
||||
class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanImage2VideoLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_image": "image_embeds",
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanImage2VideoLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports image-to-video tasks for wan2.1."
|
||||
)
|
||||
|
||||
|
||||
class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanImage2VideoLoopBeforeDenoiser,
|
||||
Wan22LoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanImage2VideoLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports image-to-video tasks for Wan2.2."
|
||||
)
|
||||
|
||||
|
||||
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanFLF2VLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_image": "image_embeds",
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanFLF2VLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports FLF2V tasks for wan2.1."
|
||||
)
|
||||
|
||||
@@ -15,21 +15,29 @@
|
||||
import html
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import is_ftfy_available, is_torchvision_available, logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -51,6 +59,103 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
def get_t5_prompt_embeds(
|
||||
text_encoder: UMT5EncoderModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
):
|
||||
dtype = text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def encode_image(
|
||||
image: PipelineImageInput,
|
||||
image_processor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModel,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
image = image_processor(images=image, return_tensors="pt").to(device)
|
||||
image_embeds = image_encoder(**image, output_hidden_states=True)
|
||||
return image_embeds.hidden_states[-2]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def encode_vae_image(
|
||||
video_tensor: torch.Tensor,
|
||||
vae: AutoencoderKLWan,
|
||||
generator: torch.Generator,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
latent_channels: int = 16,
|
||||
):
|
||||
if not isinstance(video_tensor, torch.Tensor):
|
||||
raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.")
|
||||
|
||||
if isinstance(generator, list) and len(generator) != video_tensor.shape[0]:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}."
|
||||
)
|
||||
|
||||
video_tensor = video_tensor.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
video_latents = [
|
||||
retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
||||
for i in range(video_tensor.shape[0])
|
||||
]
|
||||
video_latents = torch.cat(video_latents, dim=0)
|
||||
else:
|
||||
video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax")
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean)
|
||||
.view(1, latent_channels, 1, 1, 1)
|
||||
.to(video_latents.device, video_latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to(
|
||||
video_latents.device, video_latents.dtype
|
||||
)
|
||||
video_latents = (video_latents - latents_mean) * latents_std
|
||||
|
||||
return video_latents
|
||||
|
||||
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@@ -71,16 +176,12 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -107,47 +208,13 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@staticmethod
|
||||
def _get_t5_prompt_embeds(
|
||||
components,
|
||||
prompt: Union[str, List[str]],
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
):
|
||||
dtype = components.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
|
||||
text_inputs = components.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
def encode_prompt(
|
||||
components,
|
||||
prompt: str,
|
||||
device: Optional[torch.device] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prepare_unconditional_embeds: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
@@ -158,32 +225,29 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_videos_per_prompt (`int`):
|
||||
number of videos that should be generated per prompt
|
||||
prepare_unconditional_embeds (`bool`):
|
||||
whether to use prepare unconditional embeddings or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of text tokens to be used for the generation process.
|
||||
"""
|
||||
device = device or components._execution_device
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
|
||||
prompt_embeds = get_t5_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if prepare_unconditional_embeds and negative_prompt_embeds is None:
|
||||
if prepare_unconditional_embeds:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
@@ -199,18 +263,14 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
|
||||
components, negative_prompt, max_sequence_length, device
|
||||
negative_prompt_embeds = get_t5_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if prepare_unconditional_embeds:
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -219,7 +279,6 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(block_state)
|
||||
|
||||
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
||||
block_state.device = components._execution_device
|
||||
|
||||
# Encode input prompt
|
||||
@@ -227,16 +286,382 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
block_state.prompt_embeds,
|
||||
block_state.negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
components,
|
||||
block_state.prompt,
|
||||
block_state.device,
|
||||
1,
|
||||
block_state.prepare_unconditional_embeds,
|
||||
block_state.negative_prompt,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
components=components,
|
||||
prompt=block_state.prompt,
|
||||
device=block_state.device,
|
||||
prepare_unconditional_embeds=components.requires_unconditional_embeds,
|
||||
negative_prompt=block_state.negative_prompt,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
)
|
||||
|
||||
# Add outputs
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanImageResizeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height", type_hint=int, default=480),
|
||||
InputParam("width", type_hint=int, default=832),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("resized_image", type_hint=PIL.Image.Image),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
max_area = block_state.height * block_state.width
|
||||
|
||||
image = block_state.image
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial
|
||||
block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
block_state.resized_image = image.resize((block_state.width, block_state.height))
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanImageCropResizeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Resize step that resize the last_image to the same size of first frame image with center crop."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"
|
||||
),
|
||||
InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("resized_last_image", type_hint=PIL.Image.Image),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
height = block_state.resized_image.height
|
||||
width = block_state.resized_image.width
|
||||
image = block_state.last_image
|
||||
|
||||
# Calculate resize ratio to match first frame dimensions
|
||||
resize_ratio = max(width / image.width, height / image.height)
|
||||
|
||||
# Resize the image
|
||||
width = round(image.width * resize_ratio)
|
||||
height = round(image.height * resize_ratio)
|
||||
size = [width, height]
|
||||
resized_image = transforms.functional.center_crop(image, size)
|
||||
block_state.resized_last_image = resized_image
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("image_processor", CLIPImageProcessor),
|
||||
ComponentSpec("image_encoder", CLIPVisionModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
image = block_state.resized_image
|
||||
|
||||
image_embeds = encode_image(
|
||||
image_processor=components.image_processor,
|
||||
image_encoder=components.image_encoder,
|
||||
image=image,
|
||||
device=device,
|
||||
)
|
||||
block_state.image_embeds = image_embeds
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("image_processor", CLIPImageProcessor),
|
||||
ComponentSpec("image_encoder", CLIPVisionModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
first_frame_image = block_state.resized_image
|
||||
last_frame_image = block_state.resized_last_image
|
||||
|
||||
image_embeds = encode_image(
|
||||
image_processor=components.image_processor,
|
||||
image_encoder=components.image_encoder,
|
||||
image=[first_frame_image, last_frame_image],
|
||||
device=device,
|
||||
)
|
||||
block_state.image_embeds = image_embeds
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"first_frame_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="video latent representation with the first frame image condition",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(components, block_state):
|
||||
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||
)
|
||||
if block_state.num_frames is not None and (
|
||||
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
|
||||
)
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
image = block_state.resized_image
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
num_frames = block_state.num_frames or components.default_num_frames
|
||||
|
||||
image_tensor = components.video_processor.preprocess(image, height=height, width=width).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
if image_tensor.dim() == 4:
|
||||
image_tensor = image_tensor.unsqueeze(2)
|
||||
|
||||
video_tensor = torch.cat(
|
||||
[
|
||||
image_tensor,
|
||||
image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width),
|
||||
],
|
||||
dim=2,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
block_state.first_frame_latents = encode_vae_image(
|
||||
video_tensor=video_tensor,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"first_last_frame_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="video latent representation with the first and last frame images condition",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(components, block_state):
|
||||
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||
)
|
||||
if block_state.num_frames is not None and (
|
||||
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
|
||||
)
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
first_frame_image = block_state.resized_image
|
||||
last_frame_image = block_state.resized_last_image
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
num_frames = block_state.num_frames or components.default_num_frames
|
||||
|
||||
first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
first_image_tensor = first_image_tensor.unsqueeze(2)
|
||||
|
||||
last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
last_image_tensor = last_image_tensor.unsqueeze(2)
|
||||
|
||||
video_tensor = torch.cat(
|
||||
[
|
||||
first_image_tensor,
|
||||
first_image_tensor.new_zeros(
|
||||
first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width
|
||||
),
|
||||
last_image_tensor,
|
||||
],
|
||||
dim=2,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
block_state.first_last_frame_latents = encode_vae_image(
|
||||
video_tensor=video_tensor,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -16,96 +16,244 @@ from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
WanInputStep,
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22DenoiseStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
WanDenoiseStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeImageEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeImageEncoderStep,
|
||||
)
|
||||
from .decoders import WanDecodeStep
|
||||
from .denoise import WanDenoiseStep
|
||||
from .encoders import WanTextEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# before_denoise: text2vid
|
||||
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
# wan2.1
|
||||
# wan2.1: text2vid
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanInputStep,
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2vid,)
|
||||
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanBeforeDenoiseStep,
|
||||
]
|
||||
block_names = ["text2vid"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2vid.\n"
|
||||
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
|
||||
)
|
||||
|
||||
|
||||
# denoise: text2vid
|
||||
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["denoise"]
|
||||
block_trigger_inputs = [None]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: image2video
|
||||
## image encoder
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: FLF2v
|
||||
|
||||
|
||||
## image encoder
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_last_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: auto blocks
|
||||
## image encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## denoise
|
||||
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanFLF2VCoreDenoiseStep,
|
||||
WanImage2VideoCoreDenoiseStep,
|
||||
WanCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["flf2v", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2vid tasks.."
|
||||
" - `WanDenoiseStep` (denoise) for text2vid tasks."
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
# decode: all task (text2img, img2img, inpainting)
|
||||
class WanAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [WanDecodeStep]
|
||||
block_names = ["non-inpaint"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
|
||||
|
||||
|
||||
# text2vid
|
||||
# auto pipeline blocks
|
||||
class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoBeforeDenoiseStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
WanAutoDenoiseStep,
|
||||
WanAutoDecodeStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"before_denoise",
|
||||
"image_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decoder",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -116,29 +264,211 @@ class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# wan22
|
||||
# wan2.2: text2vid
|
||||
|
||||
|
||||
## denoise
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22DenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.2: image2video
|
||||
## denoise
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
Wan22Image2VideoCoreDenoiseStep,
|
||||
Wan22CoreDenoiseStep,
|
||||
]
|
||||
block_names = ["image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
Wan22AutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# presets for wan2.1 and wan2.2
|
||||
# YiYi Notes: should we move these to doc?
|
||||
# wan2.1
|
||||
TEXT2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanInputStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", WanDenoiseStep),
|
||||
("decode", WanDecodeStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
|
||||
("denoise", WanImage2VideoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
FLF2V_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
|
||||
("denoise", WanFLF2VDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("before_denoise", WanAutoBeforeDenoiseStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanAutoDecodeStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# wan2.2 presets
|
||||
|
||||
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# presets all blocks (wan and wan22)
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2video": TEXT2VIDEO_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
"wan2.1": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS,
|
||||
"flf2v": FLF2V_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
},
|
||||
"wan2.2": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
|
||||
"auto": AUTO_BLOCKS_WAN22,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from ...utils import logging
|
||||
@@ -35,6 +37,13 @@ class WanModularPipeline(
|
||||
|
||||
default_blocks_name = "WanAutoBlocks"
|
||||
|
||||
# override the default_blocks_name in base class, which is just return self.default_blocks_name
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22AutoBlocks"
|
||||
else:
|
||||
return "WanAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_height * self.vae_scale_factor_spatial
|
||||
@@ -59,6 +68,13 @@ class WanModularPipeline(
|
||||
def default_sample_num_frames(self):
|
||||
return 21
|
||||
|
||||
@property
|
||||
def patch_size_spatial(self):
|
||||
patch_size_spatial = 2
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
patch_size_spatial = self.transformer.config.patch_size[1]
|
||||
return patch_size_spatial
|
||||
|
||||
@property
|
||||
def vae_scale_factor_spatial(self):
|
||||
vae_scale_factor = 8
|
||||
@@ -86,3 +102,19 @@ class WanModularPipeline(
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
num_channels_latents = self.vae.config.z_dim
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires_unconditional_embeds = False
|
||||
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
@property
|
||||
def num_train_timesteps(self):
|
||||
num_train_timesteps = 1000
|
||||
if hasattr(self, "scheduler") and self.scheduler is not None:
|
||||
num_train_timesteps = self.scheduler.config.num_train_timesteps
|
||||
return num_train_timesteps
|
||||
|
||||
@@ -404,6 +404,7 @@ else:
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
]
|
||||
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -566,6 +567,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .chronoedit import ChronoEditPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
|
||||
@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
out_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
@@ -117,6 +117,7 @@ from .stable_diffusion_xl import (
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
|
||||
|
||||
|
||||
@@ -214,6 +215,24 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanImageToVideoPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanVideoToVideoPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("kandinsky", KandinskyPipeline),
|
||||
@@ -247,6 +266,9 @@ SUPPORTED_TASKS_MAPPINGS = [
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_INPAINT_PIPELINES_MAPPING,
|
||||
AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
|
||||
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
|
||||
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
|
||||
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
|
||||
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_chronoedit"] = ["ChronoEditPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_chronoedit import ChronoEditPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,752 @@
|
||||
# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import PIL
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, ChronoEditTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import ChronoEditPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
>>> from transformers import CLIPVisionModel
|
||||
|
||||
>>> # Available models: nvidia/ChronoEdit-14B-Diffusers
|
||||
>>> model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
>>> image_encoder = CLIPVisionModel.from_pretrained(
|
||||
... model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
... )
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> transformer = ChronoEditTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = ChronoEditPipeline.from_pretrained(
|
||||
... model_id, vae=vae, image_encoder=image_encoder, transformer=transformer, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = load_image("https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png")
|
||||
>>> max_area = 720 * 1280
|
||||
>>> aspect_ratio = image.height / image.width
|
||||
>>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
>>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
>>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
>>> image = image.resize((width, height))
|
||||
>>> prompt = (
|
||||
... "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
... "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
... )
|
||||
|
||||
>>> output = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... height=height,
|
||||
... width=width,
|
||||
... num_frames=5,
|
||||
... guidance_scale=5.0,
|
||||
... enable_temporal_reasoning=False,
|
||||
... num_temporal_reasoning_steps=0,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class ChronoEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for image-to-video generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
image_encoder ([`CLIPVisionModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
|
||||
the
|
||||
[clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
|
||||
variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
transformer: ChronoEditTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
image_encoder=image_encoder,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
self.image_processor = image_processor
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
|
||||
def encode_image(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
image = self.image_processor(images=image, return_tensors="pt").to(device)
|
||||
image_embeds = self.image_encoder(**image, output_hidden_states=True)
|
||||
return image_embeds.hidden_states[-2]
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if image is not None and image_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
if image is None and image_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
|
||||
)
|
||||
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
# modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
|
||||
video_condition = torch.cat(
|
||||
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
||||
)
|
||||
video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [
|
||||
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
||||
]
|
||||
latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latent_condition = latent_condition.to(dtype)
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(latent_condition.device)
|
||||
|
||||
return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
enable_temporal_reasoning: bool = False,
|
||||
num_temporal_reasoning_steps: int = 0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height of the generated video.
|
||||
width (`int`, defaults to `832`):
|
||||
The width of the generated video.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `negative_prompt` input argument.
|
||||
image_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
|
||||
image embeddings are generated from the `image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`ChronoEditPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
enable_temporal_reasoning (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable temporal reasoning.
|
||||
num_temporal_reasoning_steps (`int`, *optional*, defaults to `0`):
|
||||
The number of steps to enable temporal reasoning.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~ChronoEditPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`ChronoEditPipelineOutput`] is returned, otherwise a `tuple` is returned
|
||||
where the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
image_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
num_frames = 5 if not enable_temporal_reasoning else num_frames
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
||||
)
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Encode image embedding
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
if image_embeds is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
||||
image_embeds = image_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.z_dim
|
||||
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
|
||||
latents, condition = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if enable_temporal_reasoning and i == num_temporal_reasoning_steps:
|
||||
latents = latents[:, :, [0, -1]]
|
||||
condition = condition[:, :, [0, -1]]
|
||||
|
||||
for j in range(len(self.scheduler.model_outputs)):
|
||||
if self.scheduler.model_outputs[j] is not None:
|
||||
if latents.shape[-3] != self.scheduler.model_outputs[j].shape[-3]:
|
||||
self.scheduler.model_outputs[j] = self.scheduler.model_outputs[j][:, :, [0, -1]]
|
||||
if self.scheduler.last_sample is not None:
|
||||
self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]]
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
if enable_temporal_reasoning and latents.shape[2] > 2:
|
||||
video_edit = self.vae.decode(latents[:, :, [0, -1]], return_dict=False)[0]
|
||||
video_reason = self.vae.decode(latents[:, :, :-1], return_dict=False)[0]
|
||||
video = torch.cat([video_reason, video_edit[:, :, 1:]], dim=2)
|
||||
else:
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return ChronoEditPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronoEditPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for ChronoEdit pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockFlat",
|
||||
"CrossAttnDownBlockFlat",
|
||||
"CrossAttnDownBlockFlat",
|
||||
"DownBlockFlat",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str] = (
|
||||
param_names: Tuple[str, ...] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str] = (
|
||||
param_names: Tuple[str, ...] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
"nerstf.mlp.3.weight",
|
||||
),
|
||||
param_shapes: Tuple[Tuple[int]] = (
|
||||
param_shapes: Tuple[Tuple[int, int], ...] = (
|
||||
(256, 93),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
n_hidden_layers: int = 6,
|
||||
act_fn: str = "swish",
|
||||
insert_direction_at: int = 4,
|
||||
background: Tuple[float] = (
|
||||
background: Tuple[float, ...] = (
|
||||
255.0,
|
||||
255.0,
|
||||
255.0,
|
||||
|
||||
@@ -648,6 +648,21 @@ class ChromaTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ChronoEditTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -182,6 +182,21 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22AutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -542,6 +557,21 @@ class ChromaPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ChronoEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CLIPImageProjection(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -32,6 +32,20 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
|
||||
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
|
||||
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
|
||||
config.addinivalue_line("markers", "training: marks tests for training functionality")
|
||||
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
|
||||
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
|
||||
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
|
||||
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
|
||||
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
|
||||
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
|
||||
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
|
||||
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
|
||||
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
|
||||
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
|
||||
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
|
||||
)
|
||||
|
||||
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
|
||||
"Model parameters don't match!"
|
||||
)
|
||||
assert all(
|
||||
torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
|
||||
), "Model parameters don't match!"
|
||||
|
||||
# Remove a shard file
|
||||
cached_shard_file = try_to_load_from_cache(
|
||||
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
|
||||
# Verify error mentions the missing shard
|
||||
error_msg = str(context.exception)
|
||||
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
|
||||
f"Expected error about missing shard, got: {error_msg}"
|
||||
)
|
||||
assert (
|
||||
cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
|
||||
), f"Expected error about missing shard, got: {error_msg}"
|
||||
|
||||
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
|
||||
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
|
||||
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
assert download_requests.count("HEAD") == 3, (
|
||||
"3 HEAD requests one for config, one for model, and one for shard index file."
|
||||
)
|
||||
assert (
|
||||
download_requests.count("HEAD") == 3
|
||||
), "3 HEAD requests one for config, one for model, and one for shard index file."
|
||||
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
|
||||
"We should call only `model_info` to check for commit hash and knowing if shard index is present."
|
||||
)
|
||||
assert (
|
||||
"HEAD" == cache_requests[0] and len(cache_requests) == 2
|
||||
), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
|
||||
|
||||
def test_weight_overwrite(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .common import ModelTesterMixin
|
||||
from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
from .lora import LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
QuantizationTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
)
|
||||
from .single_file import SingleFileTesterMixin
|
||||
from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionTesterMixin",
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"GGUFTesterMixin",
|
||||
"GroupOffloadTesterMixin",
|
||||
"IPAdapterTesterMixin",
|
||||
"LayerwiseCastingTesterMixin",
|
||||
"LoraTesterMixin",
|
||||
"MemoryTesterMixin",
|
||||
"ModelOptTesterMixin",
|
||||
"ModelTesterMixin",
|
||||
"QuantizationTesterMixin",
|
||||
"QuantoTesterMixin",
|
||||
"SingleFileTesterMixin",
|
||||
"TorchAoTesterMixin",
|
||||
"TorchCompileTesterMixin",
|
||||
"TrainingTesterMixin",
|
||||
]
|
||||
@@ -0,0 +1,180 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
from ...testing_utils import is_attention, require_accelerator, torch_device
|
||||
|
||||
|
||||
@is_attention
|
||||
@require_accelerator
|
||||
class AttentionTesterMixin:
|
||||
"""
|
||||
Mixin class for testing attention processor and module functionality on models.
|
||||
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
|
||||
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: attention
|
||||
Use `pytest -m "not attention"` to skip these tests
|
||||
"""
|
||||
|
||||
base_precision = 1e-3
|
||||
|
||||
def test_fuse_unfuse_qkv_projections(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if not hasattr(model, "fuse_qkv_projections"):
|
||||
pytest.skip("Model does not support QKV projection fusion.")
|
||||
|
||||
# Get output before fusion
|
||||
with torch.no_grad():
|
||||
output_before_fusion = model(**inputs_dict)
|
||||
if isinstance(output_before_fusion, dict):
|
||||
output_before_fusion = output_before_fusion.to_tuple()[0]
|
||||
|
||||
# Fuse projections
|
||||
model.fuse_qkv_projections()
|
||||
|
||||
# Verify fusion occurred by checking for fused attributes
|
||||
has_fused_projections = False
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
|
||||
has_fused_projections = True
|
||||
assert module.fused_projections, "fused_projections flag should be True"
|
||||
break
|
||||
|
||||
if has_fused_projections:
|
||||
# Get output after fusion
|
||||
with torch.no_grad():
|
||||
output_after_fusion = model(**inputs_dict)
|
||||
if isinstance(output_after_fusion, dict):
|
||||
output_after_fusion = output_after_fusion.to_tuple()[0]
|
||||
|
||||
# Verify outputs match
|
||||
assert torch.allclose(
|
||||
output_before_fusion, output_after_fusion, atol=self.base_precision
|
||||
), "Output should not change after fusing projections"
|
||||
|
||||
# Unfuse projections
|
||||
model.unfuse_qkv_projections()
|
||||
|
||||
# Verify unfusion occurred
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
|
||||
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
|
||||
assert not module.fused_projections, "fused_projections flag should be False"
|
||||
|
||||
# Get output after unfusion
|
||||
with torch.no_grad():
|
||||
output_after_unfusion = model(**inputs_dict)
|
||||
if isinstance(output_after_unfusion, dict):
|
||||
output_after_unfusion = output_after_unfusion.to_tuple()[0]
|
||||
|
||||
# Verify outputs still match
|
||||
assert torch.allclose(
|
||||
output_before_fusion, output_after_unfusion, atol=self.base_precision
|
||||
), "Output should match original after unfusing projections"
|
||||
|
||||
def test_get_set_processor(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# Check if model has attention processors
|
||||
if not hasattr(model, "attn_processors"):
|
||||
pytest.skip("Model does not have attention processors.")
|
||||
|
||||
# Test getting processors
|
||||
processors = model.attn_processors
|
||||
assert isinstance(processors, dict), "attn_processors should return a dict"
|
||||
assert len(processors) > 0, "Model should have at least one attention processor"
|
||||
|
||||
# Test that all processors can be retrieved via get_processor
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
processor = module.get_processor()
|
||||
assert processor is not None, "get_processor should return a processor"
|
||||
|
||||
# Test setting a new processor
|
||||
new_processor = AttnProcessor()
|
||||
module.set_processor(new_processor)
|
||||
retrieved_processor = module.get_processor()
|
||||
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
|
||||
|
||||
def test_attention_processor_dict(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict of new processors
|
||||
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
|
||||
|
||||
# Set processors using dict
|
||||
model.set_attn_processor(new_processors)
|
||||
|
||||
# Verify all processors were set
|
||||
updated_processors = model.attn_processors
|
||||
for key in current_processors.keys():
|
||||
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
|
||||
|
||||
def test_attention_processor_count_mismatch_raises_error(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict with wrong number of processors
|
||||
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
|
||||
|
||||
# Verify error is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
@@ -0,0 +1,514 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
|
||||
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: Optional[Union[str, torch.device]] = None,
|
||||
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
def calculate_expected_num_shards(index_map_path):
|
||||
"""
|
||||
Calculate expected number of shards from index file.
|
||||
|
||||
Args:
|
||||
index_map_path: Path to the sharded checkpoint index file
|
||||
|
||||
Returns:
|
||||
int: Expected number of shards
|
||||
"""
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
|
||||
|
||||
|
||||
def check_device_map_is_respected(model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
while len(param_name) > 0 and param_name not in device_map:
|
||||
param_name = ".".join(param_name.split(".")[:-1])
|
||||
if param_name not in device_map:
|
||||
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||
|
||||
param_device = device_map[param_name]
|
||||
if param_device in ["cpu", "disk"]:
|
||||
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||
else:
|
||||
assert param.device == torch.device(
|
||||
param_device
|
||||
), f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
"""
|
||||
Base mixin class for model testing with common test methods.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
|
||||
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
model_class = None
|
||||
base_precision = 1e-3
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
def get_init_dict(self):
|
||||
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
raise NotImplementedError(
|
||||
"get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict."
|
||||
)
|
||||
|
||||
def test_from_save_pretrained(self, expected_max_diff=5e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
|
||||
# check if all parameters shape are the same
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert (
|
||||
param_1.shape == param_2.shape
|
||||
), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
|
||||
|
||||
# non-variant cannot be loaded
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs())
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if (
|
||||
hasattr(self.model_class, "_keep_in_fp32_modules")
|
||||
and self.model_class._keep_in_fp32_modules is None
|
||||
):
|
||||
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
|
||||
)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self, expected_max_diff=1e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
first = model(**self.get_dummy_inputs())
|
||||
if isinstance(first, dict):
|
||||
first = first.to_tuple()[0]
|
||||
|
||||
second = model(**self.get_dummy_inputs())
|
||||
if isinstance(second, dict):
|
||||
second = second.to_tuple()[0]
|
||||
|
||||
# Remove NaN values and compute max difference
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
|
||||
# Filter out NaN values
|
||||
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||
first_filtered = first_flat[mask]
|
||||
second_filtered = second_flat[mask]
|
||||
|
||||
max_diff = torch.abs(first_filtered - second_filtered).max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_output(self, expected_output_shape=None):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
assert (
|
||||
output.shape == expected_output_shape
|
||||
), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
|
||||
# Track progress in https://github.com/pytorch/pytorch/issues/77764
|
||||
device = t.device
|
||||
if device.type == "mps":
|
||||
t = t.to("cpu")
|
||||
t[t != t] = 0
|
||||
return t.to(device)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
assert torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
), (
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
)
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_model_config_to_json_string(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
|
||||
json_string = model.config.to_json_string()
|
||||
assert isinstance(json_string, str), "Config to_json_string should return a string"
|
||||
assert len(json_string) > 0, "JSON string should not be empty"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(torch_device not in ["cuda", "xpu"])
|
||||
def test_from_save_pretrained_float16_bfloat16(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for torch_dtype in [torch.bfloat16, torch.float16]:
|
||||
model.to(torch_dtype).save_pretrained(tmp_dir)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == torch_dtype
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**get_dummy_inputs())
|
||||
output_loaded = model_loaded(**get_dummy_inputs())
|
||||
|
||||
assert torch.allclose(
|
||||
output, output_loaded, atol=1e-4
|
||||
), f"Loaded model output differs for {torch_dtype}"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints(self):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match after sharded save/load"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_variant(self):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
assert os.path.exists(
|
||||
os.path.join(tmp_dir, index_filename)
|
||||
), f"Variant index file {index_filename} should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match after variant sharded save/load"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_parallel_loading(self):
|
||||
import time
|
||||
|
||||
from diffusers.utils import constants
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
# Save original values to restore after test
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
# Load without parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
start_time = time.time()
|
||||
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
sequential_load_time = time.time() - start_time
|
||||
model_sequential = model_sequential.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Load with parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
start_time = time.time()
|
||||
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
parallel_load_time = time.time() - start_time
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], output_parallel[0], atol=1e-5
|
||||
), "Output should match with parallel loading"
|
||||
|
||||
# Verify parallel loading is faster or at least not significantly slower
|
||||
# For small test models, the difference might be negligible or even slightly slower due to overhead
|
||||
# so we just check that parallel loading completed successfully and outputs match
|
||||
assert (
|
||||
parallel_load_time < sequential_load_time
|
||||
), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
|
||||
finally:
|
||||
# Restore original values
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_model_parallelism(self):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with model parallelism"
|
||||
@@ -0,0 +1,162 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_compile,
|
||||
require_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_torch_compile
|
||||
@require_accelerator
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class TorchCompileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing torch.compile functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: compile
|
||||
Use `pytest -m "not compile"` to skip these tests
|
||||
"""
|
||||
|
||||
different_shapes_for_compilation = None
|
||||
|
||||
def setup_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
if self.model_class._repeated_blocks is None:
|
||||
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_with_group_offloading(self):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch_device,
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_works_with_aot(self):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
@@ -0,0 +1,109 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
|
||||
from ...others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTesterMixin:
|
||||
"""
|
||||
Mixin class for testing push_to_hub functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
"""
|
||||
|
||||
identifier = uuid.uuid4()
|
||||
repo_id = f"test-model-{identifier}"
|
||||
org_repo_id = f"valid_org/{repo_id}-org"
|
||||
|
||||
def test_push_to_hub(self):
|
||||
"""Test pushing model to hub and loading it back."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.repo_id, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=TOKEN, repo_id=self.repo_id)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(
|
||||
p1, p2
|
||||
), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
"""Test pushing model to hub in organization namespace."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.org_repo_id, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(self.org_repo_id)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
|
||||
|
||||
new_model = self.model_class.from_pretrained(self.org_repo_id)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(
|
||||
p1, p2
|
||||
), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.org_repo_id, token=TOKEN)
|
||||
|
||||
def test_push_to_hub_library_name(self):
|
||||
"""Test that library_name in model card is set to 'diffusers'."""
|
||||
if not is_jinja_available():
|
||||
pytest.skip("Model card tests cannot be performed without Jinja installed.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.repo_id, token=TOKEN)
|
||||
|
||||
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
|
||||
assert (
|
||||
model_card.library_name == "diffusers"
|
||||
), f"Expected library_name 'diffusers', got {model_card.library_name}"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
@@ -0,0 +1,205 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention_processor import IPAdapterAttnProcessor
|
||||
|
||||
from ...testing_utils import is_ip_adapter, torch_device
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
"""
|
||||
Create a dummy IP Adapter state dict for testing.
|
||||
|
||||
Args:
|
||||
model: The model to create IP adapter weights for
|
||||
|
||||
Returns:
|
||||
dict: IP adapter state dict with to_k_ip and to_v_ip weights
|
||||
"""
|
||||
ip_state_dict = {}
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
# Skip self-attention processors
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", None)
|
||||
if cross_attention_dim is None:
|
||||
continue
|
||||
|
||||
# Get hidden size based on model architecture
|
||||
hidden_size = getattr(model.config, "hidden_size", cross_attention_dim)
|
||||
|
||||
# Create IP adapter processor to get state dict structure
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
|
||||
ip_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
key_id += 2
|
||||
|
||||
return {"ip_adapter": ip_state_dict}
|
||||
|
||||
|
||||
def check_if_ip_adapter_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if IP Adapter processors are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if IP Adapter is correctly set, False otherwise
|
||||
"""
|
||||
for module in model.attn_processors.values():
|
||||
if isinstance(module, IPAdapterAttnProcessor):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_ip_adapter
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
Mixin class for testing IP Adapter functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: ip_adapter
|
||||
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||
"""
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||
|
||||
def test_load_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create dummy IP adapter state dict
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
|
||||
# Load IP adapter
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Create dummy image embeds for IP adapter
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
inputs_dict_with_adapter = inputs_dict.copy()
|
||||
inputs_dict_with_adapter["image_embeds"] = image_embeds
|
||||
|
||||
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with IP Adapter enabled"
|
||||
|
||||
def test_ip_adapter_scale(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create and load dummy IP adapter state dict
|
||||
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
# Test scale = 0.0 (no effect)
|
||||
model.set_ip_adapter_scale(0.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Test scale = 1.0 (full effect)
|
||||
model.set_ip_adapter_scale(1.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should differ with different scales
|
||||
assert not torch.allclose(
|
||||
output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with different IP Adapter scales"
|
||||
|
||||
def test_unload_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Save original processors
|
||||
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
|
||||
|
||||
# Unload IP adapter
|
||||
model.unload_ip_adapter()
|
||||
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
|
||||
|
||||
# Verify processors are restored
|
||||
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||
|
||||
def test_ip_adapter_save_load(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict()
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_before_save = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save the IP adapter weights
|
||||
save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
|
||||
import safetensors.torch
|
||||
|
||||
safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
|
||||
|
||||
# Unload and reload
|
||||
model.unload_ip_adapter()
|
||||
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
|
||||
|
||||
# Reload from saved file
|
||||
loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
|
||||
model._load_ip_adapter_weights([loaded_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should match before and after save/load
|
||||
assert torch.allclose(
|
||||
output_before_save, output_after_load, atol=1e-4, rtol=1e-4
|
||||
), "Output should match before and after save/load"
|
||||
@@ -0,0 +1,220 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||
|
||||
from ...testing_utils import is_lora, require_peft_backend, torch_device
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if LoRA layers are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if LoRA is correctly set, False otherwise
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_lora
|
||||
@require_peft_backend
|
||||
class LoraTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA/PEFT functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: lora
|
||||
Use `pytest -m "not lora"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with LoRA enabled"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
assert os.path.isfile(
|
||||
os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
), "LoRA weights file not created"
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with LoRA enabled"
|
||||
assert torch.allclose(
|
||||
outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
|
||||
), "Outputs should match before and after save/load"
|
||||
|
||||
def test_lora_wrong_adapter_name_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
|
||||
|
||||
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
# Perturb the metadata in the state dict
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||
@@ -0,0 +1,443 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import compute_module_sizes
|
||||
|
||||
from diffusers.utils.testing_utils import _check_safetensors_serialization
|
||||
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
backend_synchronize,
|
||||
is_cpu_offload,
|
||||
is_group_offload,
|
||||
is_memory,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import check_device_map_is_respected
|
||||
|
||||
|
||||
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
|
||||
"""Helper to cast tensor inputs from one dtype to another."""
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
|
||||
inputs_dict[key] = value.to(to_dtype)
|
||||
return inputs_dict
|
||||
|
||||
|
||||
def require_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_group_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support group offloading.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@is_cpu_offload
|
||||
class CPUOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing CPU offloading functionality.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- model_split_percents: List of percentages for splitting model across devices
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cpu_offload
|
||||
Use `pytest -m "not cpu_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
@require_offload_support
|
||||
def test_cpu_offload(self):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with CPU offloading"
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_without_safetensors(self):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_with_safetensors(self):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with disk offloading (safetensors)"
|
||||
|
||||
|
||||
@is_group_offload
|
||||
class GroupOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing group offloading functionality.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: group_offload
|
||||
Use `pytest -m "not group_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_group_offload_support
|
||||
def test_group_offloading(self, record_stream=False):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
assert all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
), "Group offloading hook should be set"
|
||||
model.eval()
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading1 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||
output_with_group_offloading2 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading3 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
|
||||
)
|
||||
output_with_group_offloading4 = run_forward(model)
|
||||
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading1, atol=1e-5
|
||||
), "Output should match with block-level offloading"
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading2, atol=1e-5
|
||||
), "Output should match with non-blocking block-level offloading"
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading3, atol=1e-5
|
||||
), "Output should match with leaf-level offloading"
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading4, atol=1e-5
|
||||
), "Output should match with leaf-level offloading with stream"
|
||||
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||
)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
return "generator" in params
|
||||
|
||||
def _run_forward(model, inputs_dict):
|
||||
accepts_generator = _has_generator_arg(model)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = _run_forward(model, inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading, atol=atol
|
||||
), "Output should match with disk-based group offloading"
|
||||
|
||||
|
||||
class LayerwiseCastingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing layerwise dtype casting for memory optimization.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_memory(self):
|
||||
MB_TOLERANCE = 0.2
|
||||
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||
|
||||
def reset_memory_stats():
|
||||
gc.collect()
|
||||
backend_synchronize(torch_device)
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
def get_memory_usage(storage_dtype, compute_dtype):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
|
||||
reset_memory_stats()
|
||||
model(**inputs_dict)
|
||||
model_memory_footprint = model.get_memory_footprint()
|
||||
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
|
||||
|
||||
return model_memory_footprint, peak_inference_memory_allocated_mb
|
||||
|
||||
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
|
||||
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
|
||||
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
|
||||
torch.float8_e4m3fn, torch.bfloat16
|
||||
)
|
||||
|
||||
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
|
||||
assert (
|
||||
fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
|
||||
), "Memory footprint should decrease with lower precision storage"
|
||||
|
||||
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
|
||||
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
|
||||
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
|
||||
assert (
|
||||
fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
|
||||
), "Peak memory should be lower with bf16 compute on newer GPUs"
|
||||
|
||||
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
|
||||
# bytes. This only happens for some models, so we allow a small tolerance.
|
||||
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
|
||||
assert (
|
||||
fp8_e4m3_fp32_max_memory < fp32_max_memory
|
||||
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||
), "Peak memory should be lower or within tolerance with fp8 storage"
|
||||
|
||||
|
||||
@is_memory
|
||||
@require_accelerator
|
||||
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
|
||||
"""
|
||||
Combined mixin class for all memory optimization tests including CPU/disk offloading,
|
||||
group offloading, and layerwise dtype casting.
|
||||
|
||||
This mixin inherits from:
|
||||
- CPUOffloadTesterMixin: CPU and disk offloading tests
|
||||
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
|
||||
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: memory
|
||||
Use `pytest -m "not memory"` to skip these tests
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,833 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
|
||||
from diffusers.utils.import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_gguf_available,
|
||||
is_nvidia_modelopt_available,
|
||||
is_optimum_quanto_available,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes,
|
||||
is_gguf,
|
||||
is_modelopt,
|
||||
is_quanto,
|
||||
is_torchao,
|
||||
nightly,
|
||||
require_accelerate,
|
||||
require_accelerator,
|
||||
require_bitsandbytes_version_greater,
|
||||
require_gguf_version_greater_or_equal,
|
||||
require_quanto,
|
||||
require_torchao_version_greater_or_equal,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_nvidia_modelopt_available():
|
||||
import modelopt.torch.quantization as mtq
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QLinear
|
||||
|
||||
if is_gguf_available():
|
||||
pass
|
||||
|
||||
if is_torchao_available():
|
||||
|
||||
if is_torchao_version(">=", "0.9.0"):
|
||||
pass
|
||||
|
||||
|
||||
@require_accelerator
|
||||
class QuantizationTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for quantization testing.
|
||||
|
||||
Backend-specific mixins should:
|
||||
1. Implement _create_quantized_model(config_kwargs)
|
||||
2. Implement _verify_if_layer_quantized(name, module, config_kwargs)
|
||||
3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.)
|
||||
4. Use @pytest.mark.parametrize to create tests that call the common test methods below
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
"""
|
||||
Create a quantized model with the given config kwargs.
|
||||
|
||||
Args:
|
||||
config_kwargs: Quantization config parameters
|
||||
**extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder)
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _create_quantized_model")
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
|
||||
|
||||
def _is_module_quantized(self, module):
|
||||
"""
|
||||
Check if a module is quantized. Returns True if quantized, False otherwise.
|
||||
Default implementation tries _verify_if_layer_quantized and catches exceptions.
|
||||
Subclasses can override for more efficient checking.
|
||||
"""
|
||||
try:
|
||||
self._verify_if_layer_quantized("", module, {})
|
||||
return True
|
||||
except (AssertionError, AttributeError):
|
||||
return False
|
||||
|
||||
def _load_unquantized_model(self):
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {})
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _test_quantization_num_parameters(self, config_kwargs):
|
||||
model = self._load_unquantized_model()
|
||||
num_params = model.num_parameters()
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
num_params_quantized = model_quantized.num_parameters()
|
||||
|
||||
assert (
|
||||
num_params == num_params_quantized
|
||||
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
|
||||
|
||||
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
|
||||
model = self._load_unquantized_model()
|
||||
mem = model.get_memory_footprint()
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
mem_quantized = model_quantized.get_memory_footprint()
|
||||
|
||||
ratio = mem / mem_quantized
|
||||
assert (
|
||||
ratio >= expected_memory_reduction
|
||||
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
|
||||
|
||||
def _test_quantization_inference(self, config_kwargs):
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model_quantized(**inputs)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN"
|
||||
|
||||
def _test_quantization_dtype_assignment(self, config_kwargs):
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.to(torch.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
device_0 = f"{torch_device}:0"
|
||||
model.to(device=device_0, dtype=torch.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.float()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.half()
|
||||
|
||||
model.to(torch_device)
|
||||
|
||||
def _test_quantization_lora_inference(self, config_kwargs):
|
||||
try:
|
||||
from peft import LoraConfig
|
||||
except ImportError:
|
||||
pytest.skip("peft is not available")
|
||||
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})")
|
||||
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None with LoRA"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
|
||||
|
||||
def _test_quantization_serialization(self, config_kwargs):
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_pretrained(tmpdir, safe_serialization=True)
|
||||
|
||||
model_loaded = self.model_class.from_pretrained(tmpdir)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model_loaded(**inputs)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert not torch.isnan(output).any(), "Loaded model output contains NaN"
|
||||
|
||||
def _test_quantized_layers(self, config_kwargs):
|
||||
model_fp = self._load_unquantized_model()
|
||||
num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear))
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
num_fp32_modules = 0
|
||||
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
|
||||
for name, module in model_quantized.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
|
||||
num_fp32_modules += 1
|
||||
|
||||
expected_quantized_layers = num_linear_layers - num_fp32_modules
|
||||
|
||||
num_quantized_layers = 0
|
||||
for name, module in model_quantized.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
|
||||
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
|
||||
continue
|
||||
self._verify_if_layer_quantized(name, module, config_kwargs)
|
||||
num_quantized_layers += 1
|
||||
|
||||
assert (
|
||||
num_quantized_layers > 0
|
||||
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
|
||||
assert (
|
||||
num_quantized_layers == expected_quantized_layers
|
||||
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
|
||||
|
||||
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
|
||||
"""
|
||||
Test that modules specified in modules_to_not_convert are not quantized.
|
||||
|
||||
Args:
|
||||
config_kwargs: Base quantization config kwargs
|
||||
modules_to_not_convert: List of module names to exclude from quantization
|
||||
"""
|
||||
# Create config with modules_to_not_convert
|
||||
config_kwargs_with_exclusion = config_kwargs.copy()
|
||||
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
|
||||
|
||||
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
|
||||
|
||||
# Find a module that should NOT be quantized
|
||||
found_excluded = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is in the exclusion list
|
||||
if any(excluded in name for excluded in modules_to_not_convert):
|
||||
found_excluded = True
|
||||
# This module should NOT be quantized
|
||||
assert not self._is_module_quantized(
|
||||
module
|
||||
), f"Module {name} should not be quantized but was found to be quantized"
|
||||
|
||||
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
|
||||
|
||||
# Find a module that SHOULD be quantized (not in exclusion list)
|
||||
found_quantized = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is NOT in the exclusion list
|
||||
if not any(excluded in name for excluded in modules_to_not_convert):
|
||||
if self._is_module_quantized(module):
|
||||
found_quantized = True
|
||||
break
|
||||
|
||||
assert found_quantized, "No quantized layers found outside of excluded modules"
|
||||
|
||||
# Compare memory footprint with fully quantized model
|
||||
model_fully_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
|
||||
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
|
||||
|
||||
assert (
|
||||
mem_with_exclusion > mem_fully_quantized
|
||||
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
|
||||
|
||||
def _test_quantization_device_map(self, config_kwargs):
|
||||
"""
|
||||
Test that quantized models work correctly with device_map="auto".
|
||||
|
||||
Args:
|
||||
config_kwargs: Base quantization config kwargs
|
||||
"""
|
||||
model = self._create_quantized_model(config_kwargs, device_map="auto")
|
||||
|
||||
# Verify device map is set
|
||||
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
|
||||
assert model.hf_device_map is not None, "hf_device_map should not be None"
|
||||
|
||||
# Verify inference works
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN"
|
||||
|
||||
|
||||
@is_bitsandbytes
|
||||
@nightly
|
||||
@require_accelerator
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
class BitsAndBytesTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing BitsAndBytes quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test
|
||||
|
||||
Pytest mark: bitsandbytes
|
||||
Use `pytest -m "not bitsandbytes"` to skip these tests
|
||||
"""
|
||||
|
||||
# Standard BnB configs tested for all models
|
||||
# Subclasses can override to add or modify configs
|
||||
BNB_CONFIGS = {
|
||||
"4bit_nf4": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.float16,
|
||||
},
|
||||
"4bit_fp4": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "fp4",
|
||||
"bnb_4bit_compute_dtype": torch.float16,
|
||||
},
|
||||
"8bit": {
|
||||
"load_in_8bit": True,
|
||||
},
|
||||
}
|
||||
|
||||
BNB_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"4bit_nf4": 3.0,
|
||||
"4bit_fp4": 3.0,
|
||||
"8bit": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = BitsAndBytesConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
|
||||
assert (
|
||||
module.weight.__class__ == expected_weight_class
|
||||
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_num_parameters(self, config_name):
|
||||
self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_memory_footprint(self, config_name):
|
||||
expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
|
||||
self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected)
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_inference(self, config_name):
|
||||
self._test_quantization_inference(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_dtype_assignment(self, config_name):
|
||||
self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_lora_inference(self, config_name):
|
||||
self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_serialization(self, config_name):
|
||||
self._test_quantization_serialization(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantized_layers(self, config_name):
|
||||
self._test_quantized_layers(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_config_serialization(self, config_name):
|
||||
model = self._create_quantized_model(self.BNB_CONFIGS[config_name])
|
||||
|
||||
assert "quantization_config" in model.config, "Missing quantization_config"
|
||||
_ = model.config["quantization_config"].to_dict()
|
||||
_ = model.config["quantization_config"].to_diff_dict()
|
||||
_ = model.config["quantization_config"].to_json_string()
|
||||
|
||||
def test_bnb_original_dtype(self):
|
||||
config_name = list(self.BNB_CONFIGS.keys())[0]
|
||||
config_kwargs = self.BNB_CONFIGS[config_name]
|
||||
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype"
|
||||
assert model.config["_pre_quantization_dtype"] in [
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"
|
||||
|
||||
def test_bnb_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
|
||||
config_kwargs = self.BNB_CONFIGS["4bit_nf4"]
|
||||
|
||||
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert (
|
||||
module.weight.dtype == torch.float32
|
||||
), f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
else:
|
||||
assert (
|
||||
module.weight.dtype == torch.uint8
|
||||
), f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
finally:
|
||||
if original_fp32_modules is not None:
|
||||
self.model_class._keep_in_fp32_modules = original_fp32_modules
|
||||
|
||||
def test_bnb_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude)
|
||||
|
||||
def test_bnb_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
|
||||
|
||||
|
||||
@is_quanto
|
||||
@nightly
|
||||
@require_quanto
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
class QuantoTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing Quanto quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype
|
||||
|
||||
Pytest mark: quanto
|
||||
Use `pytest -m "not quanto"` to skip these tests
|
||||
"""
|
||||
|
||||
QUANTO_WEIGHT_TYPES = {
|
||||
"float8": {"weights_dtype": "float8"},
|
||||
"int8": {"weights_dtype": "int8"},
|
||||
"int4": {"weights_dtype": "int4"},
|
||||
"int2": {"weights_dtype": "int2"},
|
||||
}
|
||||
|
||||
QUANTO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"float8": 1.5,
|
||||
"int8": 1.5,
|
||||
"int4": 3.0,
|
||||
"int2": 7.0,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = QuantoConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_num_parameters(self, weight_type_name):
|
||||
self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_memory_footprint(self, weight_type_name):
|
||||
expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_inference(self, weight_type_name):
|
||||
self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantized_layers(self, weight_type_name):
|
||||
self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantization_lora_inference(self, weight_type_name):
|
||||
self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantization_serialization(self, weight_type_name):
|
||||
self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
def test_quanto_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude)
|
||||
|
||||
def test_quanto_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"])
|
||||
|
||||
|
||||
@is_torchao
|
||||
@require_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing TorchAO quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- TORCHAO_QUANT_TYPES: Dict of quantization type strings to test
|
||||
|
||||
Pytest mark: torchao
|
||||
Use `pytest -m "not torchao"` to skip these tests
|
||||
"""
|
||||
|
||||
TORCHAO_QUANT_TYPES = {
|
||||
"int4wo": {"quant_type": "int4_weight_only"},
|
||||
"int8wo": {"quant_type": "int8_weight_only"},
|
||||
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
|
||||
}
|
||||
|
||||
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"int4wo": 3.0,
|
||||
"int8wo": 1.5,
|
||||
"int8dq": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = TorchAoConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_num_parameters(self, quant_type):
|
||||
self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_memory_footprint(self, quant_type):
|
||||
expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_inference(self, quant_type):
|
||||
self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantized_layers(self, quant_type):
|
||||
self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantization_lora_inference(self, quant_type):
|
||||
self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantization_serialization(self, quant_type):
|
||||
self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
def test_torchao_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
# Get a module name that exists in the model - this needs to be set by test classes
|
||||
# For now, use a generic pattern that should work with transformer models
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(
|
||||
self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
|
||||
)
|
||||
|
||||
def test_torchao_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
|
||||
@is_gguf
|
||||
@nightly
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
@require_gguf_version_greater_or_equal("0.10.0")
|
||||
class GGUFTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing GGUF quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- gguf_filename: URL or path to the GGUF file
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: gguf
|
||||
Use `pytest -m "not gguf"` to skip these tests
|
||||
"""
|
||||
|
||||
gguf_filename = None
|
||||
|
||||
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
|
||||
if config_kwargs is None:
|
||||
config_kwargs = {"compute_dtype": torch.bfloat16}
|
||||
|
||||
config = GGUFQuantizationConfig(**config_kwargs)
|
||||
kwargs = {
|
||||
"quantization_config": config,
|
||||
"torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
|
||||
}
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_single_file(self.gguf_filename, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
|
||||
from diffusers.quantizers.gguf.utils import GGUFParameter
|
||||
|
||||
assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
|
||||
assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
|
||||
assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"
|
||||
|
||||
def test_gguf_quantization_inference(self):
|
||||
self._test_quantization_inference({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
|
||||
_keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model()
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
|
||||
finally:
|
||||
self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules
|
||||
|
||||
def test_gguf_quantization_dtype_assignment(self):
|
||||
self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_quantization_lora_inference(self):
|
||||
self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_dequantize_model(self):
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
||||
|
||||
model = self._create_quantized_model()
|
||||
model.dequantize()
|
||||
|
||||
def _check_for_gguf_linear(model):
|
||||
has_children = list(model.children())
|
||||
if not has_children:
|
||||
return
|
||||
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
|
||||
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
|
||||
|
||||
for name, module in model.named_children():
|
||||
_check_for_gguf_linear(module)
|
||||
|
||||
def test_gguf_quantized_layers(self):
|
||||
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
|
||||
|
||||
|
||||
@is_modelopt
|
||||
@nightly
|
||||
@require_accelerator
|
||||
@require_accelerate
|
||||
@require_modelopt_version_greater_or_equal("0.33.1")
|
||||
class ModelOptTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing NVIDIA ModelOpt quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test
|
||||
|
||||
Pytest mark: modelopt
|
||||
Use `pytest -m "not modelopt"` to skip these tests
|
||||
"""
|
||||
|
||||
MODELOPT_CONFIGS = {
|
||||
"fp8": {"quant_type": "FP8"},
|
||||
"int8": {"quant_type": "INT8"},
|
||||
"int4": {"quant_type": "INT4"},
|
||||
}
|
||||
|
||||
MODELOPT_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"fp8": 1.5,
|
||||
"int8": 1.5,
|
||||
"int4": 3.0,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = NVIDIAModelOptConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_num_parameters(self, config_name):
|
||||
self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
|
||||
def test_modelopt_quantization_memory_footprint(self, config_name):
|
||||
expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
|
||||
def test_modelopt_quantization_inference(self, config_name):
|
||||
self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_dtype_assignment(self, config_name):
|
||||
self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_lora_inference(self, config_name):
|
||||
self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_serialization(self, config_name):
|
||||
self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantized_layers(self, config_name):
|
||||
self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
def test_modelopt_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude)
|
||||
|
||||
def test_modelopt_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"])
|
||||
@@ -0,0 +1,247 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_single_file,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
|
||||
"""Download a single file checkpoint from the Hub to a temporary directory."""
|
||||
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
|
||||
return path
|
||||
|
||||
|
||||
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
|
||||
"""Download diffusers config files (excluding weights) from a repository."""
|
||||
path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_patterns=[
|
||||
"**/*.ckpt",
|
||||
"*.ckpt",
|
||||
"**/*.bin",
|
||||
"*.bin",
|
||||
"**/*.pt",
|
||||
"*.pt",
|
||||
"**/*.safetensors",
|
||||
"*.safetensors",
|
||||
],
|
||||
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||
local_dir=tmpdir,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@is_single_file
|
||||
class SingleFileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing single file loading for models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- ckpt_path: Path or Hub path to the single file checkpoint
|
||||
- subfolder: (Optional) Subfolder within the repo
|
||||
- torch_dtype: (Optional) torch dtype to use for testing
|
||||
|
||||
Pytest mark: single_file
|
||||
Use `pytest -m "not single_file"` to skip these tests
|
||||
"""
|
||||
|
||||
pretrained_model_name_or_path = None
|
||||
ckpt_path = None
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_model_config(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
pretrained_kwargs["device"] = torch_device
|
||||
single_file_kwargs["device"] = torch_device
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading: "
|
||||
f"pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_model_parameters(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
pretrained_kwargs["device"] = torch_device
|
||||
single_file_kwargs["device"] = torch_device
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict_single_file = model_single_file.state_dict()
|
||||
|
||||
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||
"Model parameters keys differ between pretrained and single file loading. "
|
||||
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
|
||||
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key in state_dict.keys():
|
||||
param = state_dict[key]
|
||||
param_single_file = state_dict_single_file[key]
|
||||
|
||||
assert param.shape == param_single_file.shape, (
|
||||
f"Parameter shape mismatch for {key}: "
|
||||
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||
)
|
||||
|
||||
assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
|
||||
f"Parameter values differ for {key}: "
|
||||
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_local_files_only(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
# Load with config parameter
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
|
||||
)
|
||||
|
||||
# Load pretrained for comparison
|
||||
pretrained_kwargs = {}
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
|
||||
# Compare configs
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
|
||||
def test_single_file_loading_dtype(self):
|
||||
for dtype in [torch.float32, torch.float16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
|
||||
|
||||
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
|
||||
|
||||
# Cleanup
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_checkpoint_variant_loading(self):
|
||||
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
|
||||
return
|
||||
|
||||
for ckpt_path in self.alternate_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
single_file_kwargs = {}
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
@@ -0,0 +1,224 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
||||
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
|
||||
|
||||
|
||||
@is_training
|
||||
@require_torch_accelerator_with_training
|
||||
class TrainingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing training functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Expected properties to be implemented by subclasses:
|
||||
- output_shape: Tuple defining the expected output shape
|
||||
|
||||
Pytest mark: training
|
||||
Use `pytest -m "not training"` to skip these tests
|
||||
"""
|
||||
|
||||
def test_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
def test_training_with_ema(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
ema_model = EMAModel(model.parameters())
|
||||
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
ema_model.step(model.parameters())
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = self.model_class(**init_dict)
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
|
||||
|
||||
# check enable works
|
||||
model.enable_gradient_checkpointing()
|
||||
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
|
||||
|
||||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self, expected_set=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if expected_set is None:
|
||||
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
for submodule in model.modules():
|
||||
if hasattr(submodule, "gradient_checkpointing"):
|
||||
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
|
||||
modules_with_gc_enabled[submodule.__class__.__name__] = True
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == expected_set, (
|
||||
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} "
|
||||
f"do not match expected set {expected_set}"
|
||||
)
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
|
||||
|
||||
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if skip is None:
|
||||
skip = set()
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict)
|
||||
if isinstance(out, dict):
|
||||
out = out.sample if hasattr(out, "sample") else out.to_tuple()[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy)
|
||||
if isinstance(out_2, dict):
|
||||
out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
assert (
|
||||
loss - loss_2
|
||||
).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
|
||||
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
if param.grad is None:
|
||||
continue
|
||||
|
||||
assert torch_all_close(
|
||||
param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol
|
||||
), f"Gradient mismatch for {name}"
|
||||
|
||||
def test_mixed_precision_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# Test with float16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Test with bfloat16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
model.zero_grad()
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
@@ -0,0 +1,316 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class FluxTransformerTesterConfig:
|
||||
model_class = FluxTransformer2DModel
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
pretrained_model_kwargs = {"subfolder": "transformer"}
|
||||
|
||||
def get_init_dict(self):
|
||||
"""Return Flux model initialization arguments."""
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"pooled_projection_dim": 32,
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 1
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
|
||||
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
"""Test that deprecated 3D img_ids and txt_ids still work."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
|
||||
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
|
||||
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
|
||||
|
||||
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
|
||||
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
|
||||
|
||||
inputs_dict["txt_ids"] = text_ids_3d
|
||||
inputs_dict["img_ids"] = image_ids_3d
|
||||
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
)
|
||||
|
||||
|
||||
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
||||
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
continue
|
||||
|
||||
joint_attention_dim = model.config["joint_attention_dim"]
|
||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||
sd = FluxIPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=(
|
||||
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
|
||||
),
|
||||
num_image_text_embeds=4,
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = {}
|
||||
sd = image_projection.state_dict()
|
||||
ip_image_projection_state_dict.update(
|
||||
{
|
||||
"proj.weight": sd["image_embeds.weight"],
|
||||
"proj.bias": sd["image_embeds.bias"],
|
||||
"norm.weight": sd["norm.weight"],
|
||||
"norm.bias": sd["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
del sd
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux Transformer."""
|
||||
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height=4, width=4):
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height=4, width=4):
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
||||
subfolder = "transformer"
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
|
||||
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionConfig,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
ChronoEditPipeline,
|
||||
ChronoEditTransformer3DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class ChronoEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ChronoEditPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# TODO: impl FlowDPMSolverMultistepScheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = ChronoEditTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=36,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
image_dim=4,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_encoder_config = CLIPVisionConfig(
|
||||
hidden_size=4,
|
||||
projection_dim=4,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
image_size=32,
|
||||
intermediate_size=16,
|
||||
patch_size=1,
|
||||
)
|
||||
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_processor = CLIPImageProcessor(crop_size=32, size=32)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image_height = 16
|
||||
image_width = 16
|
||||
image = Image.new("RGB", (image_width, image_height))
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative", # TODO
|
||||
"height": image_height,
|
||||
"width": image_width,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"num_frames": 5,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.4525, 0.4520, 0.4485, 0.4534, 0.4523, 0.4522, 0.4529, 0.4528, 0.5022, 0.5064, 0.5011, 0.5061, 0.5028, 0.4979, 0.5117, 0.5192])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"ChronoEditPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
|
||||
)
|
||||
def test_save_load_float16(self):
|
||||
pass
|
||||
@@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
|
||||
output_native = linear.forward_native(x)
|
||||
output_cuda = linear.forward_cuda(x)
|
||||
|
||||
assert torch.allclose(output_native, output_cuda, 1e-2), (
|
||||
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
|
||||
)
|
||||
assert torch.allclose(
|
||||
output_native, output_cuda, 1e-2
|
||||
), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
|
||||
|
||||
|
||||
@nightly
|
||||
|
||||
+199
-83
@@ -13,7 +13,6 @@ import struct
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from collections import UserDict
|
||||
from contextlib import contextmanager
|
||||
@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import pytest
|
||||
import requests
|
||||
from numpy.linalg import norm
|
||||
from packaging import version
|
||||
@@ -241,7 +241,6 @@ def parse_flag_from_env(key, default=False):
|
||||
|
||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
|
||||
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
@@ -267,7 +266,7 @@ def slow(test_case):
|
||||
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||
return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
|
||||
|
||||
|
||||
def nightly(test_case):
|
||||
@@ -277,33 +276,149 @@ def nightly(test_case):
|
||||
Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
|
||||
return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
|
||||
|
||||
|
||||
def is_torch_compile(test_case):
|
||||
"""
|
||||
Decorator marking a test that runs compile tests in the diffusers CI.
|
||||
|
||||
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
|
||||
|
||||
Decorator marking a test as a torch.compile test. These tests can be filtered using:
|
||||
pytest -m "not compile" to skip
|
||||
pytest -m compile to run only these tests
|
||||
"""
|
||||
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
|
||||
return pytest.mark.compile(test_case)
|
||||
|
||||
|
||||
def is_single_file(test_case):
|
||||
"""
|
||||
Decorator marking a test as a single file loading test. These tests can be filtered using:
|
||||
pytest -m "not single_file" to skip
|
||||
pytest -m single_file to run only these tests
|
||||
"""
|
||||
return pytest.mark.single_file(test_case)
|
||||
|
||||
|
||||
def is_lora(test_case):
|
||||
"""
|
||||
Decorator marking a test as a LoRA test. These tests can be filtered using:
|
||||
pytest -m "not lora" to skip
|
||||
pytest -m lora to run only these tests
|
||||
"""
|
||||
return pytest.mark.lora(test_case)
|
||||
|
||||
|
||||
def is_ip_adapter(test_case):
|
||||
"""
|
||||
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
|
||||
pytest -m "not ip_adapter" to skip
|
||||
pytest -m ip_adapter to run only these tests
|
||||
"""
|
||||
return pytest.mark.ip_adapter(test_case)
|
||||
|
||||
|
||||
def is_training(test_case):
|
||||
"""
|
||||
Decorator marking a test as a training test. These tests can be filtered using:
|
||||
pytest -m "not training" to skip
|
||||
pytest -m training to run only these tests
|
||||
"""
|
||||
return pytest.mark.training(test_case)
|
||||
|
||||
|
||||
def is_attention(test_case):
|
||||
"""
|
||||
Decorator marking a test as an attention test. These tests can be filtered using:
|
||||
pytest -m "not attention" to skip
|
||||
pytest -m attention to run only these tests
|
||||
"""
|
||||
return pytest.mark.attention(test_case)
|
||||
|
||||
|
||||
def is_memory(test_case):
|
||||
"""
|
||||
Decorator marking a test as a memory optimization test. These tests can be filtered using:
|
||||
pytest -m "not memory" to skip
|
||||
pytest -m memory to run only these tests
|
||||
"""
|
||||
return pytest.mark.memory(test_case)
|
||||
|
||||
|
||||
def is_cpu_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a CPU offload test. These tests can be filtered using:
|
||||
pytest -m "not cpu_offload" to skip
|
||||
pytest -m cpu_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.cpu_offload(test_case)
|
||||
|
||||
|
||||
def is_group_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a group offload test. These tests can be filtered using:
|
||||
pytest -m "not group_offload" to skip
|
||||
pytest -m group_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.group_offload(test_case)
|
||||
|
||||
|
||||
def is_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
|
||||
pytest -m "not bitsandbytes" to skip
|
||||
pytest -m bitsandbytes to run only these tests
|
||||
"""
|
||||
return pytest.mark.bitsandbytes(test_case)
|
||||
|
||||
|
||||
def is_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
|
||||
pytest -m "not quanto" to skip
|
||||
pytest -m quanto to run only these tests
|
||||
"""
|
||||
return pytest.mark.quanto(test_case)
|
||||
|
||||
|
||||
def is_torchao(test_case):
|
||||
"""
|
||||
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
|
||||
pytest -m "not torchao" to skip
|
||||
pytest -m torchao to run only these tests
|
||||
"""
|
||||
return pytest.mark.torchao(test_case)
|
||||
|
||||
|
||||
def is_gguf(test_case):
|
||||
"""
|
||||
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
|
||||
pytest -m "not gguf" to skip
|
||||
pytest -m gguf to run only these tests
|
||||
"""
|
||||
return pytest.mark.gguf(test_case)
|
||||
|
||||
|
||||
def is_modelopt(test_case):
|
||||
"""
|
||||
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
|
||||
pytest -m "not modelopt" to skip
|
||||
pytest -m modelopt to run only these tests
|
||||
"""
|
||||
return pytest.mark.modelopt(test_case)
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_version_greater_equal(torch_version):
|
||||
@@ -311,8 +426,9 @@ def require_torch_version_greater_equal(torch_version):
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
|
||||
return pytest.mark.skipif(
|
||||
not correct_torch_version,
|
||||
reason=f"test requires torch with the version greater than or equal to {torch_version}",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -323,8 +439,8 @@ def require_torch_version_greater(torch_version):
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
|
||||
return pytest.mark.skipif(
|
||||
not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -332,19 +448,18 @@ def require_torch_version_greater(torch_version):
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
|
||||
|
||||
|
||||
def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
def decorator(test_case):
|
||||
if torch.cuda.is_available():
|
||||
current_compute_capability = get_torch_cuda_device_capability()
|
||||
return unittest.skipUnless(
|
||||
float(current_compute_capability) == float(expected_compute_capability),
|
||||
"Test not supported for this compute capability.",
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
float(current_compute_capability) != float(expected_compute_capability),
|
||||
reason="Test not supported for this compute capability.",
|
||||
)(test_case)
|
||||
return test_case
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -352,9 +467,7 @@ def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
# These decorators are for accelerator-specific behaviours that are not GPU-specific
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
@@ -364,11 +477,11 @@ def require_torch_multi_gpu(test_case):
|
||||
-k "multi_gpu"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
|
||||
@@ -377,27 +490,28 @@ def require_torch_multi_accelerator(test_case):
|
||||
without multiple hardware accelerators.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(
|
||||
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
|
||||
return pytest.mark.skipif(
|
||||
not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
|
||||
reason="test requires multiple hardware accelerators",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp64(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_big_gpu_with_torch_cuda(test_case):
|
||||
@@ -406,17 +520,17 @@ def require_big_gpu_with_torch_cuda(test_case):
|
||||
etc.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
return pytest.mark.skipif(
|
||||
total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
)(test_case)
|
||||
|
||||
|
||||
@@ -430,12 +544,12 @@ def require_big_accelerator(test_case):
|
||||
test_case = pytest.mark.big_accelerator(test_case)
|
||||
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
|
||||
|
||||
if torch.xpu.is_available():
|
||||
device_properties = torch.xpu.get_device_properties(0)
|
||||
@@ -443,30 +557,30 @@ def require_big_accelerator(test_case):
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY,
|
||||
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
return pytest.mark.skipif(
|
||||
total_memory < BIG_GPU_MEMORY,
|
||||
reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return unittest.skipUnless(
|
||||
is_torch_available() and backend_supports_training(torch_device),
|
||||
"test requires accelerator with training support",
|
||||
return pytest.mark.skipif(
|
||||
not (is_torch_available() and backend_supports_training(torch_device)),
|
||||
reason="test requires accelerator with training support",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def skip_mps(test_case):
|
||||
"""Decorator marking a test to skip if torch_device is 'mps'"""
|
||||
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
|
||||
return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||
"""
|
||||
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||
return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
|
||||
|
||||
|
||||
def require_compel(test_case):
|
||||
@@ -474,21 +588,21 @@ def require_compel(test_case):
|
||||
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
|
||||
the library is not installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
|
||||
return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
|
||||
|
||||
|
||||
def require_onnxruntime(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
||||
return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
|
||||
|
||||
|
||||
def require_note_seq(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
|
||||
return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
|
||||
|
||||
|
||||
def require_accelerator(test_case):
|
||||
@@ -496,14 +610,14 @@ def require_accelerator(test_case):
|
||||
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
|
||||
hardware accelerator available.
|
||||
"""
|
||||
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
|
||||
return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
|
||||
|
||||
|
||||
def require_torchsde(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
|
||||
return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
|
||||
|
||||
|
||||
def require_peft_backend(test_case):
|
||||
@@ -511,35 +625,35 @@ def require_peft_backend(test_case):
|
||||
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
|
||||
transformers.
|
||||
"""
|
||||
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
|
||||
return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
|
||||
|
||||
|
||||
def require_timm(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
|
||||
return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
|
||||
return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
|
||||
|
||||
|
||||
def require_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
|
||||
return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_peft_version_greater(peft_version):
|
||||
@@ -552,8 +666,8 @@ def require_peft_version_greater(peft_version):
|
||||
correct_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse(peft_version)
|
||||
return unittest.skipUnless(
|
||||
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
return pytest.mark.skipif(
|
||||
not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -569,9 +683,9 @@ def require_transformers_version_greater(transformers_version):
|
||||
correct_transformers_version = is_transformers_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse(transformers_version)
|
||||
return unittest.skipUnless(
|
||||
correct_transformers_version,
|
||||
f"test requires transformers with the version greater than {transformers_version}",
|
||||
return pytest.mark.skipif(
|
||||
not correct_transformers_version,
|
||||
reason=f"test requires transformers with the version greater than {transformers_version}",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -582,8 +696,9 @@ def require_accelerate_version_greater(accelerate_version):
|
||||
correct_accelerate_version = is_accelerate_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("accelerate")).base_version
|
||||
) > version.parse(accelerate_version)
|
||||
return unittest.skipUnless(
|
||||
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_accelerate_version,
|
||||
reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -594,8 +709,8 @@ def require_bitsandbytes_version_greater(bnb_version):
|
||||
correct_bnb_version = is_bitsandbytes_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("bitsandbytes")).base_version
|
||||
) > version.parse(bnb_version)
|
||||
return unittest.skipUnless(
|
||||
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -606,8 +721,9 @@ def require_hf_hub_version_greater(hf_hub_version):
|
||||
correct_hf_hub_version = version.parse(
|
||||
version.parse(importlib.metadata.version("huggingface_hub")).base_version
|
||||
) > version.parse(hf_hub_version)
|
||||
return unittest.skipUnless(
|
||||
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_hf_hub_version,
|
||||
reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -618,8 +734,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
|
||||
correct_gguf_version = is_gguf_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("gguf")).base_version
|
||||
) >= version.parse(gguf_version)
|
||||
return unittest.skipUnless(
|
||||
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -630,8 +746,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("torchao")).base_version
|
||||
) >= version.parse(torchao_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -642,8 +758,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
|
||||
correct_kernels_version = is_kernels_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("kernels")).base_version
|
||||
) >= version.parse(kernels_version)
|
||||
return unittest.skipUnless(
|
||||
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -653,7 +769,7 @@ def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
"""
|
||||
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
|
||||
return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
|
||||
|
||||
|
||||
def get_python_version():
|
||||
@@ -1064,8 +1180,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
|
||||
|
||||
Args:
|
||||
test_case (`unittest.TestCase`):
|
||||
The test that will run `target_func`.
|
||||
test_case:
|
||||
The test case object that will run `target_func`.
|
||||
target_func (`Callable`):
|
||||
The function implementing the actual testing logic.
|
||||
inputs (`dict`, *optional*, defaults to `None`):
|
||||
@@ -1083,7 +1199,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
input_queue = ctx.Queue(1)
|
||||
output_queue = ctx.JoinableQueue(1)
|
||||
|
||||
# We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
|
||||
# We can't send test case objects to the child, otherwise we get issues regarding pickle.
|
||||
input_queue.put(inputs, timeout=timeout)
|
||||
|
||||
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
|
||||
|
||||
Reference in New Issue
Block a user