Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul 4209647dc8 Merge branch 'main' into fix-bnb-test 2025-04-18 08:55:15 +05:30
Marc Sun c1a2a9d405 fix 2025-04-17 18:00:25 +02:00
15 changed files with 45 additions and 389 deletions
-54
View File
@@ -133,60 +133,6 @@ output = pipe(
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### First and Last Frame Interpolation
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
def center_crop_resize(image, height, width):
# Calculate resize ratio to match first frame dimensions
resize_ratio = max(width / image.width, height / image.height)
# Resize the image
width = round(image.width * resize_ratio)
height = round(image.height * resize_ratio)
size = [width, height]
image = TF.center_crop(image, size)
return image, height, width
first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
if last_frame.size != first_frame.size:
last_frame, _, _ = center_crop_resize(last_frame, height, width)
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipe(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
### Video to Video Generation
```python
@@ -1915,22 +1915,17 @@ def main(args):
free_memory()
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1954,6 +1949,7 @@ def main(args):
lr_scheduler,
)
else:
print("I SHOULD BE HERE")
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
)
@@ -1965,14 +1961,8 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -33,6 +33,7 @@ from diffusers import DiffusionPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
@@ -299,7 +300,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class StableDiffusionXLControlNetAdapterInpaintPipeline(
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
+7 -18
View File
@@ -1407,22 +1407,17 @@ def main(args):
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1449,14 +1444,8 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -1524,22 +1524,17 @@ def main(args):
free_memory()
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1566,14 +1561,8 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -1523,22 +1523,17 @@ def main(args):
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1555,14 +1550,7 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+2 -43
View File
@@ -39,24 +39,6 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# for the FLF2V model
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
@@ -153,28 +135,6 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
elif model_type == "Wan-FLF2V-14B-720P":
config = {
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": 257 * 2,
},
}
return config
@@ -433,12 +393,11 @@ if __name__ == "__main__":
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
)
if "I2V" in args.model_type or "FLF2V" in args.model_type:
if "I2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
+1 -47
View File
@@ -18,7 +18,7 @@ import importlib
import inspect
import os
from array import array
from collections import OrderedDict, defaultdict
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,7 +38,6 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerator_device,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -305,51 +304,6 @@ def load_model_dict_into_meta(
return offload_index, state_dict_index
# Taken from
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
def _expand_device_map(device_map, param_names):
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
# Remove disk, cpu and meta devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
}
if not len(accelerator_device_map):
return
total_byte_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name)
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = param.numel() * param.element_size()
total_byte_count[device] += param_byte_count
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, byte_count in total_byte_count.items():
if device.type == "cuda":
index = device.index if device.index is not None else torch.cuda.current_device()
device_memory = torch.cuda.mem_get_info(index)[0]
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
# Allocate memory
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
-25
View File
@@ -63,9 +63,7 @@ from ..utils.hub_utils import (
populate_model_card,
)
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
@@ -1376,24 +1374,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return super().float(*args)
# Taken from `transformers`.
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
def get_parameter_or_buffer(self, target: str):
"""
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a leaf
of the model.
"""
try:
return self.get_parameter(target)
except AttributeError:
pass
try:
return self.get_buffer(target)
except AttributeError:
pass
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
@classmethod
def _load_pretrained_model(
cls,
@@ -1430,11 +1410,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
assign_to_params_buffers = None
error_msgs = []
# Optionally, warmup cuda to load the weights much faster on devices
if device_map is not None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -918,5 +918,5 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
return (output, hidden_states_masks)
return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
@@ -49,10 +49,8 @@ class WanAttnProcessor2_0:
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
encoder_hidden_states_img = encoder_hidden_states[:, :257]
encoder_hidden_states = encoder_hidden_states[:, 257:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -110,23 +108,14 @@ class WanAttnProcessor2_0:
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
@@ -141,7 +130,6 @@ class WanTimeTextImageEmbedding(nn.Module):
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
@@ -153,7 +141,7 @@ class WanTimeTextImageEmbedding(nn.Module):
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
def forward(
self,
@@ -362,7 +350,6 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()
@@ -381,7 +368,6 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
@@ -380,7 +380,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
last_image: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
@@ -399,16 +398,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents = latents.to(device=device, dtype=dtype)
image = image.unsqueeze(2)
if last_image is None:
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
else:
last_image = last_image.unsqueeze(2)
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
dim=2,
)
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
video_condition = video_condition.to(device=device, dtype=dtype)
latents_mean = (
@@ -432,11 +424,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
if last_image is None:
mask_lat_size[:, :, list(range(1, num_frames))] = 0
else:
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
@@ -488,7 +476,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
last_image: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -633,10 +620,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
if image_embeds is None:
if last_image is None:
image_embeds = self.encode_image(image, device)
else:
image_embeds = self.encode_image([image, last_image], device)
image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
@@ -647,10 +631,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
# 5. Prepare latent variables
num_channels_latents = self.vae.config.z_dim
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
if last_image is not None:
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
device, dtype=torch.float32
)
latents, condition = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
@@ -662,7 +642,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device,
generator,
latents,
last_image,
)
# 6. Denoising loop
-1
View File
@@ -129,7 +129,6 @@ from .state_dict_utils import (
convert_unet_state_dict_to_peft,
state_dict_all_zero,
)
from .testing_utils import is_accelerator_device
from .typing_utils import _get_detailed_type, _is_valid_type
-12
View File
@@ -1289,18 +1289,6 @@ if is_torch_available():
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
if is_torch_available():
# Taken from
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
"""Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
a proper `torch.device`.
"""
if device == "disk":
return False
else:
return torch.device(device).type not in ["meta", "cpu"]
# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
# Type definition of key used in `Expectations` class.
@@ -160,90 +160,3 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
def test_inference_batch_single_identical(self):
pass
class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
# TODO: impl FlowDPMSolverMultistepScheduler
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
transformer = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
image_dim=4,
pos_embed_seq_len=2 * (4 * 4 + 1),
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
projection_dim=4,
num_hidden_layers=2,
num_attention_heads=2,
image_size=4,
intermediate_size=16,
patch_size=1,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
torch.manual_seed(0)
image_processor = CLIPImageProcessor(crop_size=4, size=4)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 16
image_width = 16
image = Image.new("RGB", (image_width, image_height))
last_image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"last_image": last_image,
"prompt": "dance monkey",
"negative_prompt": "negative",
"height": image_height,
"width": image_width,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs