Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e187e239d6 | |||
| 59f4531a55 | |||
| ff80e8a27f | |||
| fe176857a2 |
@@ -28,9 +28,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio\
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
onnxruntime \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
|
||||
+30
-30
@@ -290,12 +290,12 @@
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/consisid_transformer3d
|
||||
title: ConsisIDTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
title: CogView3PlusTransformer2DModel
|
||||
- local: api/models/cogview4_transformer2d
|
||||
title: CogView4Transformer2DModel
|
||||
- local: api/models/consisid_transformer3d
|
||||
title: ConsisIDTransformer3DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
@@ -310,12 +310,12 @@
|
||||
title: HunyuanVideoTransformer3DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/lumina2_transformer2d
|
||||
title: Lumina2Transformer2DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/lumina2_transformer2d
|
||||
title: Lumina2Transformer2DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/mochi_transformer3d
|
||||
title: MochiTransformer3DModel
|
||||
- local: api/models/omnigen_transformer
|
||||
@@ -324,10 +324,10 @@
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/stable_audio_transformer
|
||||
title: StableAudioDiTModel
|
||||
- local: api/models/transformer2d
|
||||
@@ -342,10 +342,10 @@
|
||||
title: StableCascadeUNet
|
||||
- local: api/models/unet
|
||||
title: UNet1DModel
|
||||
- local: api/models/unet2d-cond
|
||||
title: UNet2DConditionModel
|
||||
- local: api/models/unet2d
|
||||
title: UNet2DModel
|
||||
- local: api/models/unet2d-cond
|
||||
title: UNet2DConditionModel
|
||||
- local: api/models/unet3d-cond
|
||||
title: UNet3DConditionModel
|
||||
- local: api/models/unet-motion
|
||||
@@ -354,10 +354,6 @@
|
||||
title: UViT2DModel
|
||||
title: UNets
|
||||
- sections:
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
title: AutoencoderDC
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/autoencoderkl_allegro
|
||||
@@ -374,6 +370,10 @@
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
title: AutoencoderDC
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
@@ -521,40 +521,40 @@
|
||||
- sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
- local: api/pipelines/stable_diffusion/img2img
|
||||
title: Image-to-image
|
||||
- local: api/pipelines/stable_diffusion/svd
|
||||
title: Image-to-video
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
title: Stable Diffusion
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
|
||||
@@ -25,8 +25,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
|
||||
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
|
||||
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
|
||||
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
@@ -79,14 +77,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
|
||||
|
||||
## CogView4LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
|
||||
|
||||
## WanLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
|
||||
|
||||
## AmusedLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
@@ -89,23 +89,6 @@ image = pipeline(prompt).images[0]
|
||||
image.save("auraflow.png")
|
||||
```
|
||||
|
||||
## Support for `torch.compile()`
|
||||
|
||||
AuraFlow can be compiled with `torch.compile()` to speed up inference latency even for different resolutions. First, install PyTorch nightly following the instructions from [here](https://pytorch.org/). The snippet below shows the changes needed to enable this:
|
||||
|
||||
```diff
|
||||
+ torch.fx.experimental._config.use_duck_shape = False
|
||||
+ pipeline.transformer = torch.compile(
|
||||
pipeline.transformer, fullgraph=True, dynamic=True
|
||||
)
|
||||
```
|
||||
|
||||
Specifying `use_duck_shape` to be `False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
|
||||
This enables from 100% (on low resolutions) to a 30% (on 1536x1536 resolution) speed improvements.
|
||||
|
||||
Thanks to [AstraliteHeart](https://github.com/huggingface/diffusers/pull/11297/) who helped us rewrite the [`AuraFlowTransformer2DModel`] class so that the above works for different resolutions ([PR](https://github.com/huggingface/diffusers/pull/11297/)).
|
||||
|
||||
## AuraFlowPipeline
|
||||
|
||||
[[autodoc]] AuraFlowPipeline
|
||||
|
||||
@@ -133,60 +133,6 @@ output = pipe(
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### First and Last Frame Interpolation
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
|
||||
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
|
||||
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
def center_crop_resize(image, height, width):
|
||||
# Calculate resize ratio to match first frame dimensions
|
||||
resize_ratio = max(width / image.width, height / image.height)
|
||||
|
||||
# Resize the image
|
||||
width = round(image.width * resize_ratio)
|
||||
height = round(image.height * resize_ratio)
|
||||
size = [width, height]
|
||||
image = TF.center_crop(image, size)
|
||||
|
||||
return image, height, width
|
||||
|
||||
first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
|
||||
if last_frame.size != first_frame.size:
|
||||
last_frame, _, _ = center_crop_resize(last_frame, height, width)
|
||||
|
||||
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
|
||||
|
||||
output = pipe(
|
||||
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Video to Video Generation
|
||||
|
||||
```python
|
||||
|
||||
@@ -1915,22 +1915,17 @@ def main(args):
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1954,6 +1949,7 @@ def main(args):
|
||||
lr_scheduler,
|
||||
)
|
||||
else:
|
||||
print("I SHOULD BE HERE")
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
@@ -1965,14 +1961,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers import DiffusionPipeline
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import (
|
||||
FromSingleFileMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
@@ -299,7 +300,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
|
||||
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
|
||||
|
||||
@@ -1407,22 +1407,17 @@ def main(args):
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1449,14 +1444,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -1524,22 +1524,17 @@ def main(args):
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1566,14 +1561,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -1523,22 +1523,17 @@ def main(args):
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1555,14 +1550,7 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
@@ -39,24 +39,6 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
# for the FLF2V model
|
||||
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
|
||||
# Add attention component mappings
|
||||
"self_attn.q": "attn1.to_q",
|
||||
"self_attn.k": "attn1.to_k",
|
||||
"self_attn.v": "attn1.to_v",
|
||||
"self_attn.o": "attn1.to_out.0",
|
||||
"self_attn.norm_q": "attn1.norm_q",
|
||||
"self_attn.norm_k": "attn1.norm_k",
|
||||
"cross_attn.q": "attn2.to_q",
|
||||
"cross_attn.k": "attn2.to_k",
|
||||
"cross_attn.v": "attn2.to_v",
|
||||
"cross_attn.o": "attn2.to_out.0",
|
||||
"cross_attn.norm_q": "attn2.norm_q",
|
||||
"cross_attn.norm_k": "attn2.norm_k",
|
||||
"attn2.to_k_img": "attn2.add_k_proj",
|
||||
"attn2.to_v_img": "attn2.add_v_proj",
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
@@ -153,28 +135,6 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-FLF2V-14B-720P":
|
||||
config = {
|
||||
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"rope_max_seq_len": 1024,
|
||||
"pos_embed_seq_len": 257 * 2,
|
||||
},
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
@@ -433,12 +393,11 @@ if __name__ == "__main__":
|
||||
vae = convert_vae()
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
|
||||
)
|
||||
|
||||
if "I2V" in args.model_type or "FLF2V" in args.model_type:
|
||||
if "I2V" in args.model_type:
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
@@ -526,7 +526,7 @@ class FluxIPAdapterMixin:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=image_encoder_dtype,
|
||||
dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device)
|
||||
.eval()
|
||||
|
||||
@@ -127,7 +127,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
adapter_name=None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -154,7 +154,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
@@ -368,8 +368,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -430,8 +451,29 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -583,7 +625,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -610,8 +651,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -650,7 +689,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
@@ -661,7 +699,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
@@ -672,7 +709,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -823,8 +859,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -886,8 +943,29 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -1170,8 +1248,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -1246,8 +1345,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -1303,8 +1423,29 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -1560,11 +1701,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -1582,8 +1719,6 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -1613,7 +1748,6 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1637,8 +1771,29 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -1921,7 +2076,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
adapter_name=None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -1940,16 +2095,34 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
|
||||
adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
|
||||
additional method before loading the adapter:
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -2071,8 +2244,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
@@ -2182,8 +2376,29 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -2643,8 +2858,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
@@ -2700,8 +2936,29 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
@@ -2878,11 +3135,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -2900,8 +3153,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -2931,7 +3182,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2955,8 +3205,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -3195,11 +3466,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -3217,8 +3484,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -3248,7 +3513,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3272,8 +3536,29 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -3514,11 +3799,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -3536,8 +3817,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -3567,7 +3846,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3591,8 +3869,29 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -3833,11 +4132,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -3855,8 +4150,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -3886,7 +4179,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -3910,8 +4202,29 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -4155,11 +4468,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -4177,8 +4486,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -4208,7 +4515,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -4232,8 +4538,29 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -4478,11 +4805,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -4500,8 +4823,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -4531,7 +4852,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -4555,8 +4875,29 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -4826,11 +5167,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -4848,8 +5185,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -4883,7 +5218,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -4907,8 +5241,29 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
@@ -5149,11 +5504,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
@@ -5171,8 +5522,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -5202,7 +5551,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -5226,8 +5574,29 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
hotswap : (`bool`, *optional*)
|
||||
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
|
||||
in-place. This means that, instead of loading an additional adapter, this will take the existing
|
||||
adapter weights and replace them with the weights of the new adapter. This can be faster and more
|
||||
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
|
||||
torch.compile, loading the new adapter does not require recompilation of the model. When using
|
||||
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
|
||||
|
||||
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
|
||||
to call an additional method before loading the adapter:
|
||||
|
||||
```py
|
||||
pipeline = ... # load diffusers pipeline
|
||||
max_rank = ... # the highest rank among all LoRAs that you want to load
|
||||
# call *before* compiling and loading the LoRA adapter
|
||||
pipeline.enable_lora_hotswap(target_rank=max_rank)
|
||||
pipeline.load_lora_weights(file_name)
|
||||
# optionally compile the model now
|
||||
```
|
||||
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
|
||||
@@ -430,7 +430,7 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
|
||||
@@ -454,18 +454,17 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
|
||||
@@ -18,7 +18,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
from array import array
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from zipfile import is_zipfile
|
||||
@@ -38,7 +38,6 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerator_device,
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
@@ -305,51 +304,6 @@ def load_model_dict_into_meta(
|
||||
return offload_index, state_dict_index
|
||||
|
||||
|
||||
# Taken from
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
|
||||
def _expand_device_map(device_map, param_names):
|
||||
new_device_map = {}
|
||||
for module, device in device_map.items():
|
||||
new_device_map.update(
|
||||
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
||||
)
|
||||
return new_device_map
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
|
||||
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
|
||||
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
|
||||
# Remove disk, cpu and meta devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
|
||||
}
|
||||
if not len(accelerator_device_map):
|
||||
return
|
||||
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
total_byte_count[device] += param_byte_count
|
||||
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, byte_count in total_byte_count.items():
|
||||
if device.type == "cuda":
|
||||
index = device.index if device.index is not None else torch.cuda.current_device()
|
||||
device_memory = torch.cuda.mem_get_info(index)[0]
|
||||
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
|
||||
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
|
||||
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
|
||||
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
|
||||
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
|
||||
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
|
||||
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
|
||||
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
|
||||
# Allocate memory
|
||||
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def _load_state_dict_into_model(
|
||||
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
||||
) -> List[str]:
|
||||
|
||||
@@ -63,9 +63,7 @@ from ..utils.hub_utils import (
|
||||
populate_model_card,
|
||||
)
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
_expand_device_map,
|
||||
_fetch_index_file,
|
||||
_fetch_index_file_legacy,
|
||||
_load_state_dict_into_model,
|
||||
@@ -1376,24 +1374,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
return super().float(*args)
|
||||
|
||||
# Taken from `transformers`.
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
|
||||
def get_parameter_or_buffer(self, target: str):
|
||||
"""
|
||||
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
|
||||
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a leaf
|
||||
of the model.
|
||||
"""
|
||||
try:
|
||||
return self.get_parameter(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
return self.get_buffer(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
@@ -1430,11 +1410,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
assign_to_params_buffers = None
|
||||
error_msgs = []
|
||||
|
||||
# Optionally, warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None:
|
||||
expanded_device_map = _expand_device_map(device_map, expected_keys)
|
||||
_caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
# Deal with offload
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
if offload_folder is None:
|
||||
|
||||
@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
@@ -686,108 +686,46 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
return x
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
patch_size = self.config.patch_size
|
||||
patch_height, patch_width = height // patch_size, width // patch_size
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# create img_sizes
|
||||
img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
|
||||
img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# create hidden_states_masks
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
|
||||
hidden_states_masks[:, : patch_height * patch_width] = 1.0
|
||||
def patchify(self, x, max_seq, img_sizes=None):
|
||||
pz2 = self.config.patch_size * self.config.patch_size
|
||||
if isinstance(x, torch.Tensor):
|
||||
B, C = x.shape[0], x.shape[1]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
else:
|
||||
hidden_states_masks = None
|
||||
B, C = len(x), x[0].shape[0]
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
|
||||
|
||||
# create img_ids
|
||||
img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
|
||||
row_indices = torch.arange(patch_height, device=device)[:, None]
|
||||
col_indices = torch.arange(patch_width, device=device)[None, :]
|
||||
img_ids[..., 1] = img_ids[..., 1] + row_indices
|
||||
img_ids[..., 2] = img_ids[..., 2] + col_indices
|
||||
img_ids = img_ids.reshape(patch_height * patch_width, -1)
|
||||
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
# Handle non-square latents
|
||||
img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
|
||||
img_ids_pad[: patch_height * patch_width, :] = img_ids
|
||||
img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
if img_sizes is not None:
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
x_masks[i, 0 : img_size[0] * img_size[1]] = 1
|
||||
B, C, S, _ = x.shape
|
||||
x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
|
||||
elif isinstance(x, torch.Tensor):
|
||||
B, C, Hp1, Wp2 = x.shape
|
||||
pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
|
||||
x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
|
||||
x = x.permute(0, 2, 4, 3, 5, 1)
|
||||
x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
|
||||
img_sizes = [[pH, pW]] * B
|
||||
x_masks = None
|
||||
else:
|
||||
img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
# patchify hidden_states
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
# Handle non-square latents
|
||||
out = torch.zeros(
|
||||
(batch_size, channels, self.max_seq, patch_size * patch_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channels, patch_height, patch_size, patch_width, patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channels, patch_height * patch_width, patch_size * patch_size
|
||||
)
|
||||
out[:, :, 0 : patch_height * patch_width] = hidden_states
|
||||
hidden_states = out
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
||||
batch_size, self.max_seq, patch_size * patch_size * channels
|
||||
)
|
||||
|
||||
else:
|
||||
# Handle square latents
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channels, patch_height, patch_size, patch_width, patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, patch_height * patch_width, patch_size * patch_size * channels
|
||||
)
|
||||
|
||||
return hidden_states, hidden_states_masks, img_sizes, img_ids
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timesteps: torch.LongTensor = None,
|
||||
encoder_hidden_states_t5: torch.Tensor = None,
|
||||
encoder_hidden_states_llama3: torch.Tensor = None,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_embeds: torch.Tensor = None,
|
||||
img_ids: Optional[torch.Tensor] = None,
|
||||
img_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
hidden_states_masks: Optional[torch.Tensor] = None,
|
||||
img_ids: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
|
||||
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
|
||||
encoder_hidden_states_t5 = encoder_hidden_states[0]
|
||||
encoder_hidden_states_llama3 = encoder_hidden_states[1]
|
||||
|
||||
if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
|
||||
deprecation_message = (
|
||||
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
|
||||
)
|
||||
deprecate("img_ids", "0.34.0", deprecation_message)
|
||||
|
||||
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
|
||||
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
|
||||
elif hidden_states_masks is not None and hidden_states.ndim != 3:
|
||||
raise ValueError(
|
||||
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
|
||||
)
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
@@ -807,19 +745,42 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
|
||||
# Patchify the input
|
||||
if hidden_states_masks is None:
|
||||
hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
|
||||
|
||||
# Embed the hidden states
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
B, C, H, W = hidden_states.shape
|
||||
patch_size = self.config.patch_size
|
||||
pH, pW = H // patch_size, W // patch_size
|
||||
out = torch.zeros(
|
||||
(B, C, self.max_seq, patch_size * patch_size),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
|
||||
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
|
||||
hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
|
||||
out[:, :, 0 : pH * pW] = hidden_states
|
||||
hidden_states = out
|
||||
|
||||
# 0. time
|
||||
timesteps = self.t_embedder(timesteps, hidden_states_type)
|
||||
p_embedder = self.p_embedder(pooled_embeds)
|
||||
temb = timesteps + p_embedder
|
||||
|
||||
encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
|
||||
hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
|
||||
if hidden_states_masks is None:
|
||||
pH, pW = img_sizes[0]
|
||||
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
|
||||
img_ids = (
|
||||
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
|
||||
.unsqueeze(0)
|
||||
.repeat(batch_size, 1, 1)
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
T5_encoder_hidden_states = encoder_hidden_states[0]
|
||||
encoder_hidden_states = encoder_hidden_states[-1]
|
||||
encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
new_encoder_hidden_states = []
|
||||
@@ -828,9 +789,9 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
|
||||
new_encoder_hidden_states.append(enc_hidden_state)
|
||||
encoder_hidden_states = new_encoder_hidden_states
|
||||
encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
|
||||
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
|
||||
encoder_hidden_states.append(encoder_hidden_states_t5)
|
||||
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
|
||||
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
encoder_hidden_states.append(T5_encoder_hidden_states)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
batch_size,
|
||||
@@ -918,5 +879,5 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
return (output, hidden_states_masks)
|
||||
return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
|
||||
|
||||
@@ -49,10 +49,8 @@ class WanAttnProcessor2_0:
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states_img = None
|
||||
if attn.add_k_proj is not None:
|
||||
# 512 is the context length of the text encoder, hardcoded for now
|
||||
image_context_length = encoder_hidden_states.shape[1] - 512
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
|
||||
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :257]
|
||||
encoder_hidden_states = encoder_hidden_states[:, 257:]
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
@@ -110,23 +108,14 @@ class WanAttnProcessor2_0:
|
||||
|
||||
|
||||
class WanImageEmbedding(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = FP32LayerNorm(in_features)
|
||||
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
|
||||
self.norm2 = FP32LayerNorm(out_features)
|
||||
if pos_embed_seq_len is not None:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
|
||||
if self.pos_embed is not None:
|
||||
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
|
||||
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
|
||||
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
|
||||
|
||||
hidden_states = self.norm1(encoder_hidden_states_image)
|
||||
hidden_states = self.ff(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
@@ -141,7 +130,6 @@ class WanTimeTextImageEmbedding(nn.Module):
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
image_embed_dim: Optional[int] = None,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -153,7 +141,7 @@ class WanTimeTextImageEmbedding(nn.Module):
|
||||
|
||||
self.image_embedder = None
|
||||
if image_embed_dim is not None:
|
||||
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
|
||||
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -362,7 +350,6 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
image_dim: Optional[int] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
rope_max_seq_len: int = 1024,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -381,7 +368,6 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
image_embed_dim=image_dim,
|
||||
pos_embed_seq_len=pos_embed_seq_len,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
|
||||
@@ -15,7 +15,7 @@ from transformers import (
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import HiDreamImagePipelineOutput
|
||||
@@ -38,6 +38,9 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
|
||||
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
|
||||
|
||||
>>> scheduler = UniPCMultistepScheduler(
|
||||
... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
|
||||
... )
|
||||
|
||||
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
|
||||
@@ -49,6 +52,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> pipe = HiDreamImagePipeline.from_pretrained(
|
||||
... "HiDream-ai/HiDream-I1-Full",
|
||||
... scheduler=scheduler,
|
||||
... tokenizer_4=tokenizer_4,
|
||||
... text_encoder_4=text_encoder_4,
|
||||
... torch_dtype=torch.bfloat16,
|
||||
@@ -144,7 +148,7 @@ def retrieve_timesteps(
|
||||
|
||||
class HiDreamImagePipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -305,10 +309,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
prompt_4: Optional[Union[str, List[str]]] = None,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
prompt_4: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
@@ -317,10 +321,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_4: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
|
||||
prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
@@ -330,177 +332,120 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = pooled_prompt_embeds.shape[0]
|
||||
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
|
||||
|
||||
device = device or self._execution_device
|
||||
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
prompt_4=prompt_4,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
|
||||
self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
negative_prompt_4 = negative_prompt_4 or negative_prompt
|
||||
|
||||
if len(negative_prompt) > 1 and len(negative_prompt) != batch_size:
|
||||
raise ValueError(f"negative_prompt must be of length 1 or {batch_size}")
|
||||
|
||||
negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
|
||||
self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
|
||||
# normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = (
|
||||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
negative_prompt_3 = (
|
||||
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
)
|
||||
negative_prompt_4 = (
|
||||
batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
|
||||
)
|
||||
|
||||
if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1:
|
||||
negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1)
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_2=negative_prompt_2,
|
||||
prompt_3=negative_prompt_3,
|
||||
prompt_4=negative_prompt_4,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
prompt_4: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
if len(prompt_2) > 1 and len(prompt_2) != batch_size:
|
||||
raise ValueError(f"prompt_2 must be of length 1 or {batch_size}")
|
||||
|
||||
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
|
||||
self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
|
||||
)
|
||||
pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
|
||||
self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
|
||||
pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1)
|
||||
|
||||
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
|
||||
if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size:
|
||||
raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}")
|
||||
|
||||
negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
|
||||
self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
|
||||
)
|
||||
|
||||
if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
|
||||
negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
|
||||
negative_pooled_prompt_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
|
||||
)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
if prompt_embeds_t5 is None:
|
||||
if prompt_embeds is None:
|
||||
prompt_3 = prompt_3 or prompt
|
||||
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
||||
|
||||
if len(prompt_3) > 1 and len(prompt_3) != batch_size:
|
||||
raise ValueError(f"prompt_3 must be of length 1 or {batch_size}")
|
||||
|
||||
prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
|
||||
|
||||
if prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
|
||||
prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds_t5 is None:
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
|
||||
if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size:
|
||||
raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}")
|
||||
|
||||
negative_prompt_embeds_t5 = self._get_t5_prompt_embeds(
|
||||
negative_prompt_3, max_sequence_length, device, dtype
|
||||
)
|
||||
|
||||
if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
|
||||
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
|
||||
|
||||
if prompt_embeds_llama3 is None:
|
||||
prompt_4 = prompt_4 or prompt
|
||||
prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
|
||||
|
||||
if len(prompt_4) > 1 and len(prompt_4) != batch_size:
|
||||
raise ValueError(f"prompt_4 must be of length 1 or {batch_size}")
|
||||
t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
|
||||
llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
|
||||
|
||||
prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
|
||||
_, seq_len, _ = t5_prompt_embeds.shape
|
||||
t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
|
||||
prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
|
||||
_, _, seq_len, dim = llama3_prompt_embeds.shape
|
||||
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
|
||||
llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None:
|
||||
negative_prompt_4 = negative_prompt_4 or negative_prompt
|
||||
negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
|
||||
prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
|
||||
|
||||
if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size:
|
||||
raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}")
|
||||
|
||||
negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds(
|
||||
negative_prompt_4, max_sequence_length, device, dtype
|
||||
)
|
||||
|
||||
if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
|
||||
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
|
||||
|
||||
# duplicate pooled_prompt_embeds for each generation per prompt
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
# duplicate t5_prompt_embeds for batch_size and num_images_per_prompt
|
||||
bs_embed, seq_len, _ = prompt_embeds_t5.shape
|
||||
if bs_embed == 1 and batch_size > 1:
|
||||
prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
|
||||
elif bs_embed > 1 and bs_embed != batch_size:
|
||||
raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}")
|
||||
prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt
|
||||
_, bs_embed, seq_len, dim = prompt_embeds_llama3.shape
|
||||
if bs_embed == 1 and batch_size > 1:
|
||||
prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
|
||||
elif bs_embed > 1 and bs_embed != batch_size:
|
||||
raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}")
|
||||
prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
|
||||
prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
|
||||
bs_embed, seq_len = negative_pooled_prompt_embeds.shape
|
||||
if bs_embed == 1 and batch_size > 1:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1)
|
||||
elif bs_embed > 1 and bs_embed != batch_size:
|
||||
raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}")
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
# duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt
|
||||
bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape
|
||||
if bs_embed == 1 and batch_size > 1:
|
||||
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
|
||||
elif bs_embed > 1 and bs_embed != batch_size:
|
||||
raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}")
|
||||
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt
|
||||
_, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape
|
||||
if bs_embed == 1 and batch_size > 1:
|
||||
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
|
||||
elif bs_embed > 1 and bs_embed != batch_size:
|
||||
raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}")
|
||||
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view(
|
||||
-1, batch_size * num_images_per_prompt, seq_len, dim
|
||||
)
|
||||
|
||||
return (
|
||||
prompt_embeds_t5,
|
||||
negative_prompt_embeds_t5,
|
||||
prompt_embeds_llama3,
|
||||
negative_prompt_embeds_llama3,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
@@ -531,115 +476,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
prompt_4,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
negative_prompt_3=None,
|
||||
negative_prompt_4=None,
|
||||
prompt_embeds_t5=None,
|
||||
prompt_embeds_llama3=None,
|
||||
negative_prompt_embeds_t5=None,
|
||||
negative_prompt_embeds_llama3=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and pooled_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and pooled_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_3 is not None and prompt_embeds_t5 is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_4 is not None and prompt_embeds_llama3 is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is None and prompt_embeds_t5 is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined."
|
||||
)
|
||||
elif prompt is None and prompt_embeds_llama3 is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
||||
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
||||
elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)):
|
||||
raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}")
|
||||
|
||||
if negative_prompt is not None and negative_pooled_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:"
|
||||
f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
|
||||
f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:"
|
||||
f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:"
|
||||
f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None:
|
||||
if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`"
|
||||
f" {negative_pooled_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None:
|
||||
if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`"
|
||||
f" {negative_prompt_embeds_t5.shape}."
|
||||
)
|
||||
if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None:
|
||||
if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`"
|
||||
f" {negative_prompt_embeds_llama3.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -706,10 +542,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_t5: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
@@ -718,7 +552,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -816,22 +649,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
[`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated. images.
|
||||
"""
|
||||
|
||||
prompt_embeds = kwargs.get("prompt_embeds", None)
|
||||
negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
|
||||
|
||||
if prompt_embeds is not None:
|
||||
deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
|
||||
deprecate("prompt_embeds", "0.34.0", deprecation_message)
|
||||
prompt_embeds_t5 = prompt_embeds[0]
|
||||
prompt_embeds_llama3 = prompt_embeds[1]
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
|
||||
deprecate("negative_prompt_embeds", "0.34.0", deprecation_message)
|
||||
negative_prompt_embeds_t5 = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@@ -841,25 +658,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
scale = math.sqrt(scale)
|
||||
width, height = int(width * scale // division * division), int(height * scale // division * division)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
prompt_4,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
negative_prompt_4=negative_prompt_4,
|
||||
prompt_embeds_t5=prompt_embeds_t5,
|
||||
prompt_embeds_llama3=prompt_embeds_llama3,
|
||||
negative_prompt_embeds_t5=negative_prompt_embeds_t5,
|
||||
negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
@@ -869,18 +667,17 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
elif pooled_prompt_embeds is not None:
|
||||
batch_size = pooled_prompt_embeds.shape[0]
|
||||
elif prompt_embeds is not None:
|
||||
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
|
||||
else:
|
||||
batch_size = 1
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode prompt
|
||||
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
|
||||
(
|
||||
prompt_embeds_t5,
|
||||
negative_prompt_embeds_t5,
|
||||
prompt_embeds_llama3,
|
||||
negative_prompt_embeds_llama3,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
@@ -893,10 +690,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
negative_prompt_4=negative_prompt_4,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds_t5=prompt_embeds_t5,
|
||||
prompt_embeds_llama3=prompt_embeds_llama3,
|
||||
negative_prompt_embeds_t5=negative_prompt_embeds_t5,
|
||||
negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
@@ -906,8 +701,13 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
|
||||
prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
|
||||
prompt_embeds_arr = []
|
||||
for n, p in zip(negative_prompt_embeds, prompt_embeds):
|
||||
if len(n.shape) == 3:
|
||||
prompt_embeds_arr.append(torch.cat([n, p], dim=0))
|
||||
else:
|
||||
prompt_embeds_arr.append(torch.cat([n, p], dim=1))
|
||||
prompt_embeds = prompt_embeds_arr
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -923,6 +723,26 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
latents,
|
||||
)
|
||||
|
||||
if latents.shape[-2] != latents.shape[-1]:
|
||||
B, C, H, W = latents.shape
|
||||
pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
|
||||
|
||||
img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
|
||||
img_ids = torch.zeros(pH, pW, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
|
||||
img_ids = img_ids.reshape(pH * pW, -1)
|
||||
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
|
||||
img_ids_pad[: pH * pW, :] = img_ids
|
||||
|
||||
img_sizes = img_sizes.unsqueeze(0).to(latents.device)
|
||||
img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
|
||||
if self.do_classifier_free_guidance:
|
||||
img_sizes = img_sizes.repeat(2 * B, 1)
|
||||
img_ids = img_ids.repeat(2 * B, 1, 1)
|
||||
else:
|
||||
img_sizes = img_ids = None
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(self.transformer.max_seq)
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
@@ -954,9 +774,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timesteps=timestep,
|
||||
encoder_hidden_states_t5=prompt_embeds_t5,
|
||||
encoder_hidden_states_llama3=prompt_embeds_llama3,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_embeds=pooled_prompt_embeds,
|
||||
img_sizes=img_sizes,
|
||||
img_ids=img_ids,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = -noise_pred
|
||||
@@ -982,9 +803,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5)
|
||||
prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3)
|
||||
pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -344,7 +344,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
)
|
||||
prompt_embeds = self.text_encoder(
|
||||
**expanded_inputs,
|
||||
pixel_values=image_embeds,
|
||||
pixel_value=image_embeds,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
@@ -404,11 +404,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
||||
return False
|
||||
|
||||
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
||||
|
||||
if is_loaded_in_8bit_bnb:
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and (
|
||||
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
|
||||
or hasattr(module._hf_hook, "hooks")
|
||||
|
||||
@@ -380,7 +380,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
last_image: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
@@ -399,16 +398,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
image = image.unsqueeze(2)
|
||||
if last_image is None:
|
||||
video_condition = torch.cat(
|
||||
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
||||
)
|
||||
else:
|
||||
last_image = last_image.unsqueeze(2)
|
||||
video_condition = torch.cat(
|
||||
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
|
||||
dim=2,
|
||||
)
|
||||
video_condition = torch.cat(
|
||||
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
||||
)
|
||||
video_condition = video_condition.to(device=device, dtype=dtype)
|
||||
|
||||
latents_mean = (
|
||||
@@ -432,11 +424,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
|
||||
if last_image is None:
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
else:
|
||||
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
@@ -488,7 +476,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
last_image: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -633,10 +620,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
if image_embeds is None:
|
||||
if last_image is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
else:
|
||||
image_embeds = self.encode_image([image, last_image], device)
|
||||
image_embeds = self.encode_image(image, device)
|
||||
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
||||
image_embeds = image_embeds.to(transformer_dtype)
|
||||
|
||||
@@ -647,10 +631,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.z_dim
|
||||
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
|
||||
if last_image is not None:
|
||||
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
|
||||
device, dtype=torch.float32
|
||||
)
|
||||
latents, condition = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_videos_per_prompt,
|
||||
@@ -662,7 +642,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
last_image,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
|
||||
@@ -171,11 +171,9 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc
|
||||
|
||||
if cls_name == "Params4bit":
|
||||
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
|
||||
msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
|
||||
if dtype:
|
||||
msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
|
||||
output_tensor = output_tensor.to(dtype)
|
||||
logger.warning_once(msg)
|
||||
logger.warning_once(
|
||||
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
if state.SCB is None:
|
||||
|
||||
@@ -129,7 +129,6 @@ from .state_dict_utils import (
|
||||
convert_unet_state_dict_to_peft,
|
||||
state_dict_all_zero,
|
||||
)
|
||||
from .testing_utils import is_accelerator_device
|
||||
from .typing_utils import _get_detailed_type, _is_valid_type
|
||||
|
||||
|
||||
|
||||
@@ -1289,18 +1289,6 @@ if is_torch_available():
|
||||
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
# Taken from
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
|
||||
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
|
||||
"""Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
|
||||
a proper `torch.device`.
|
||||
"""
|
||||
if device == "disk":
|
||||
return False
|
||||
else:
|
||||
return torch.device(device).type not in ["meta", "cpu"]
|
||||
|
||||
# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
|
||||
|
||||
# Type definition of key used in `Expectations` class.
|
||||
|
||||
@@ -43,7 +43,7 @@ enable_full_determinism()
|
||||
|
||||
class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = HiDreamImagePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
@@ -24,11 +24,9 @@ from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
LlamaConfig,
|
||||
LlamaTokenizerFast,
|
||||
LlavaConfig,
|
||||
LlavaForConditionalGeneration,
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
from transformers.models.clip import CLIPVisionConfig
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo,
|
||||
@@ -118,7 +116,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
|
||||
text_config = LlamaConfig(
|
||||
llama_text_encoder_config = LlamaConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=16,
|
||||
@@ -126,21 +124,11 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=100,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
vision_config = CLIPVisionConfig(
|
||||
hidden_size=8,
|
||||
intermediate_size=37,
|
||||
projection_dim=32,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=2,
|
||||
image_size=224,
|
||||
)
|
||||
llava_text_encoder_config = LlavaConfig(vision_config, text_config, pad_token_id=100, image_token_index=101)
|
||||
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
@@ -156,8 +144,8 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = LlavaForConditionalGeneration(llava_text_encoder_config)
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
|
||||
text_encoder = LlamaModel(llama_text_encoder_config)
|
||||
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
|
||||
@@ -165,14 +153,14 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_processor = CLIPImageProcessor(
|
||||
crop_size=224,
|
||||
crop_size=336,
|
||||
do_center_crop=True,
|
||||
do_normalize=True,
|
||||
do_resize=True,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
resample=3,
|
||||
size=224,
|
||||
size=336,
|
||||
)
|
||||
|
||||
components = {
|
||||
@@ -202,10 +190,6 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
"prompt_template": {
|
||||
"template": "{}",
|
||||
"crop_start": 0,
|
||||
"image_emb_len": 49,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 54,
|
||||
"double_return_token_id": 0,
|
||||
},
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
@@ -213,7 +197,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
"height": image_height,
|
||||
"width": image_width,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 64,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -160,90 +160,3 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
|
||||
class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# TODO: impl FlowDPMSolverMultistepScheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=36,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
image_dim=4,
|
||||
pos_embed_seq_len=2 * (4 * 4 + 1),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_encoder_config = CLIPVisionConfig(
|
||||
hidden_size=4,
|
||||
projection_dim=4,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
image_size=4,
|
||||
intermediate_size=16,
|
||||
patch_size=1,
|
||||
)
|
||||
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_processor = CLIPImageProcessor(crop_size=4, size=4)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image_height = 16
|
||||
image_width = 16
|
||||
image = Image.new("RGB", (image_width, image_height))
|
||||
last_image = Image.new("RGB", (image_width, image_height))
|
||||
inputs = {
|
||||
"image": image,
|
||||
"last_image": last_image,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative",
|
||||
"height": image_height,
|
||||
"width": image_width,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -523,15 +523,13 @@ class SlowBnb8bitTests(Base8bitTests):
|
||||
torch_dtype=torch.float16,
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
# CUDA device placement works.
|
||||
device = torch_device if torch_device != "rocm" else "cuda"
|
||||
pipeline_8bit = DiffusionPipeline.from_pretrained(
|
||||
self.model_name,
|
||||
transformer=transformer_8bit,
|
||||
text_encoder_3=text_encoder_3_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
).to(device)
|
||||
).to("cuda")
|
||||
|
||||
# Check if inference works.
|
||||
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
from diffusers.utils import is_torch_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -36,9 +30,9 @@ if is_torch_available():
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def get_memory_consumption_stat(model, inputs):
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model(**inputs)
|
||||
max_mem_allocated = backend_max_memory_allocated(torch_device)
|
||||
return max_mem_allocated
|
||||
max_memory_mem_allocated = torch.cuda.max_memory_allocated()
|
||||
return max_memory_mem_allocated
|
||||
|
||||
+4
-56
@@ -123,13 +123,11 @@ def check_pipeline_doc(overwrite=False):
|
||||
|
||||
# sort sub pipeline docs
|
||||
for pipeline_doc in pipeline_docs:
|
||||
if "sections" in pipeline_doc:
|
||||
sub_pipeline_doc = pipeline_doc["sections"]
|
||||
if "section" in pipeline_doc:
|
||||
sub_pipeline_doc = pipeline_doc["section"]
|
||||
new_sub_pipeline_doc = clean_doc_toc(sub_pipeline_doc)
|
||||
if new_sub_pipeline_doc != sub_pipeline_doc:
|
||||
diff = True
|
||||
if overwrite:
|
||||
pipeline_doc["sections"] = new_sub_pipeline_doc
|
||||
if overwrite:
|
||||
pipeline_doc["section"] = new_sub_pipeline_doc
|
||||
new_pipeline_docs.append(pipeline_doc)
|
||||
|
||||
# sort overall pipeline doc
|
||||
@@ -151,55 +149,6 @@ def check_pipeline_doc(overwrite=False):
|
||||
)
|
||||
|
||||
|
||||
def check_model_doc(overwrite=False):
|
||||
with open(PATH_TO_TOC, encoding="utf-8") as f:
|
||||
content = yaml.safe_load(f.read())
|
||||
|
||||
# Get to the API doc
|
||||
api_idx = 0
|
||||
while content[api_idx]["title"] != "API":
|
||||
api_idx += 1
|
||||
api_doc = content[api_idx]["sections"]
|
||||
|
||||
# Then to the model doc
|
||||
model_idx = 0
|
||||
while api_doc[model_idx]["title"] != "Models":
|
||||
model_idx += 1
|
||||
|
||||
diff = False
|
||||
model_docs = api_doc[model_idx]["sections"]
|
||||
new_model_docs = []
|
||||
|
||||
# sort sub model docs
|
||||
for model_doc in model_docs:
|
||||
if "sections" in model_doc:
|
||||
sub_model_doc = model_doc["sections"]
|
||||
new_sub_model_doc = clean_doc_toc(sub_model_doc)
|
||||
if new_sub_model_doc != sub_model_doc:
|
||||
diff = True
|
||||
if overwrite:
|
||||
model_doc["sections"] = new_sub_model_doc
|
||||
new_model_docs.append(model_doc)
|
||||
|
||||
# sort overall model doc
|
||||
new_model_docs = clean_doc_toc(new_model_docs)
|
||||
|
||||
if new_model_docs != model_docs:
|
||||
diff = True
|
||||
if overwrite:
|
||||
api_doc[model_idx]["sections"] = new_model_docs
|
||||
|
||||
if diff:
|
||||
if overwrite:
|
||||
content[api_idx]["sections"] = api_doc
|
||||
with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
|
||||
f.write(yaml.dump(content, allow_unicode=True))
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model doc part of the table of content is not properly sorted, run `make style` to fix this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
@@ -207,4 +156,3 @@ if __name__ == "__main__":
|
||||
|
||||
check_scheduler_doc(args.fix_and_overwrite)
|
||||
check_pipeline_doc(args.fix_and_overwrite)
|
||||
check_model_doc(args.fix_and_overwrite)
|
||||
|
||||
@@ -100,7 +100,7 @@ if __name__ == "__main__":
|
||||
"doc_path": "docs/source/en/api/loaders/lora.md",
|
||||
"src_path": "src/diffusers/loaders/lora_pipeline.py",
|
||||
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
|
||||
"src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
|
||||
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user