Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c4262eea9b | |||
| 36f7eaa994 | |||
| 7464205c47 | |||
| 634aa8d4a7 | |||
| 6399810347 | |||
| b439e5db6d | |||
| adde0818fd | |||
| 5d45b0ce9d |
@@ -418,8 +418,6 @@ jobs:
|
||||
test_location: "gguf"
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
- backend: "optimum_quanto"
|
||||
test_location: "quanto"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
|
||||
@@ -173,8 +173,6 @@
|
||||
title: gguf
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
- local: quantization/quanto
|
||||
title: quanto
|
||||
title: Quantization Methods
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
|
||||
@@ -31,11 +31,6 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
|
||||
## GGUFQuantizationConfig
|
||||
|
||||
[[autodoc]] GGUFQuantizationConfig
|
||||
|
||||
## QuantoConfig
|
||||
|
||||
[[autodoc]] QuantoConfig
|
||||
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
@@ -36,6 +36,5 @@ Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
- [Quanto](./quanto.md)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
-->
|
||||
|
||||
# Quanto
|
||||
|
||||
[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:
|
||||
|
||||
- All features are available in eager mode (works with non-traceable models)
|
||||
- Supports quantization aware training
|
||||
- Quantized models are compatible with `torch.compile`
|
||||
- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU)
|
||||
|
||||
In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`
|
||||
|
||||
```shell
|
||||
pip install optimum-quanto accelerate
|
||||
```
|
||||
|
||||
Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Skipping Quantization on specific modules
|
||||
|
||||
It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
|
||||
## Using `from_single_file` with the Quanto Backend
|
||||
|
||||
`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## Saving Quantized models
|
||||
|
||||
Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.
|
||||
|
||||
The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized
|
||||
with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="float8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# save quantized model to reuse
|
||||
transformer.save_pretrained("<your quantized model save path>")
|
||||
|
||||
# you can reload your quantized model with
|
||||
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
|
||||
```
|
||||
|
||||
## Using `torch.compile` with Quanto
|
||||
|
||||
Currently the Quanto backend supports `torch.compile` for the following quantization types:
|
||||
|
||||
- `int8` weights
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
quantization_config = QuantoConfig(weights_dtype="int8")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch_dtype
|
||||
)
|
||||
pipe.to("cuda")
|
||||
images = pipe("A cat holding a sign that says hello").images[0]
|
||||
images.save("flux-quanto-compile.png")
|
||||
```
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
### Weights
|
||||
|
||||
- float8
|
||||
- int8
|
||||
- int4
|
||||
- int2
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
@@ -10,7 +10,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|
||||
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
|
||||
@@ -94,55 +93,6 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
|
||||
|
||||
## Example usages
|
||||
|
||||
### Spatiotemporal Skip Guidance
|
||||
|
||||
**Junha Hyung\*, Kinam Kim\*, Susung Hong, Min-Jung Kim, Jaegul Choo**
|
||||
|
||||
**KAIST AI, University of Washington**
|
||||
|
||||
[*Spatiotemporal Skip Guidance (STG) for Enhanced Video Diffusion Sampling*](https://arxiv.org/abs/2411.18664) (CVPR 2025) is a simple training-free sampling guidance method for enhancing transformer-based video diffusion models. STG employs an implicit weak model via self-perturbation, avoiding the need for external models or additional training. By selectively skipping spatiotemporal layers, STG produces an aligned, degraded version of the original model to boost sample quality without compromising diversity or dynamic degree.
|
||||
|
||||
Following is the example video of STG applied to Mochi.
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/148adb59-da61-4c50-9dfa-425dcb5c23b3
|
||||
|
||||
More examples and information can be found on the [GitHub repository](https://github.com/junhahyung/STGuidance) and the [Project website](https://junhahyung.github.io/STGuidance/).
|
||||
|
||||
#### Usage example
|
||||
```python
|
||||
import torch
|
||||
from pipeline_stg_mochi import MochiSTGPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Load the pipeline
|
||||
pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# Enable memory savings
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
#--------Option--------#
|
||||
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
stg_applied_layers_idx = [34]
|
||||
stg_mode = "STG"
|
||||
stg_scale = 1.0 # 0.0 for CFG
|
||||
#----------------------#
|
||||
|
||||
# Generate video frames
|
||||
frames = pipe(
|
||||
prompt,
|
||||
height=480,
|
||||
width=480,
|
||||
num_frames=81,
|
||||
stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
stg_scale=stg_scale,
|
||||
generator = torch.Generator().manual_seed(42),
|
||||
do_rescaling=do_rescaling,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
### Adaptive Mask Inpainting
|
||||
|
||||
**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
|
||||
|
||||
@@ -1,876 +0,0 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_cogvideox import CogVideoXSTGPipeline
|
||||
|
||||
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
>>> pipe = CogVideoXSTGPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A father and son building a treehouse together, their hands covered in sawdust and smiles on their faces, realistic style."
|
||||
... )
|
||||
>>> pipe.transformer.to(memory_format=torch.channels_last)
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [11] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set to 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> # Generate video frames with STG parameters
|
||||
>>> frames = pipe(
|
||||
... prompt=prompt,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(frames, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogVideoX uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CogVideoXTransformer3DModel`]):
|
||||
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
)
|
||||
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae_scaling_factor_image * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
p = self.transformer.config.patch_size
|
||||
p_t = self.transformer.config.patch_size_t
|
||||
|
||||
base_size_width = self.transformer.config.sample_width // p
|
||||
base_size_height = self.transformer.config.sample_height // p
|
||||
|
||||
if p_t is None:
|
||||
# CogVideoX 1.0
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# CogVideoX 1.5
|
||||
base_num_frames = (num_frames + p_t - 1) // p_t
|
||||
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=None,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=base_num_frames,
|
||||
grid_type="slice",
|
||||
max_size=(base_size_height, base_size_width),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 226,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [11],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
||||
num_frames (`int`, defaults to `48`):
|
||||
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
||||
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
||||
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
||||
needs to be satisfied is that of divisibility mentioned above.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
||||
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
||||
num_frames = num_frames or self.transformer.config.sample_frames
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents
|
||||
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
|
||||
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
||||
patch_size_t = self.transformer.config.patch_size_t
|
||||
additional_frames = 0
|
||||
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
||||
additional_frames = patch_size_t - latent_frames % patch_size_t
|
||||
num_frames += additional_frames * self.vae_scale_factor_temporal
|
||||
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# for DPM-solver++
|
||||
old_pred_original_sample = None
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
self._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
)
|
||||
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
else:
|
||||
latents, old_pred_original_sample = self.scheduler.step(
|
||||
noise_pred,
|
||||
old_pred_original_sample,
|
||||
t,
|
||||
timesteps[i - 1] if i > 0 else None,
|
||||
latents,
|
||||
**extra_step_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CogVideoXPipelineOutput(frames=video)
|
||||
@@ -1,794 +0,0 @@
|
||||
# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
|
||||
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from diffusers import HunyuanVideoTransformer3DModel
|
||||
>>> from examples.community.pipeline_stg_hunyuan_video import HunyuanVideoSTGPipeline
|
||||
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = HunyuanVideoSTGPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
|
||||
>>> pipe.vae.enable_tiling()
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [2] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt="A wolf howling at the moon, with the moon subtly resembling a giant clock face, realistic style.",
|
||||
... height=320,
|
||||
... width=512,
|
||||
... num_frames=61,
|
||||
... num_inference_steps=30,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = {
|
||||
"template": (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
),
|
||||
"crop_start": 95,
|
||||
}
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
def forward_without_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using HunyuanVideo.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`LlamaModel`]):
|
||||
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
tokenizer (`LlamaTokenizer`):
|
||||
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
transformer ([`HunyuanVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder_2 ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: LlamaModel,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
transformer: HunyuanVideoTransformer3DModel,
|
||||
vae: AutoencoderKLHunyuanVideo,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
text_encoder_2: CLIPTextModel,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_llama_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_template: Dict[str, Any],
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
num_hidden_layers_to_skip: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = [prompt_template["template"].format(p) for p in prompt]
|
||||
|
||||
crop_start = prompt_template.get("crop_start", None)
|
||||
if crop_start is None:
|
||||
prompt_template_input = self.tokenizer(
|
||||
prompt_template["template"],
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
crop_start = prompt_template_input["input_ids"].shape[-1]
|
||||
# Remove <|eot_id|> token and placeholder {}
|
||||
crop_start -= 2
|
||||
|
||||
max_sequence_length += crop_start
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
max_length=max_sequence_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=True,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device=device)
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
if crop_start is not None and crop_start > 0:
|
||||
prompt_embeds = prompt_embeds[:, crop_start:]
|
||||
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 77,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder_2.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]] = None,
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
||||
prompt,
|
||||
prompt_template,
|
||||
num_videos_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
if prompt_2 is None and pooled_prompt_embeds is None:
|
||||
prompt_2 = prompt
|
||||
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
||||
prompt,
|
||||
num_videos_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=77,
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_template=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if prompt_template is not None:
|
||||
if not isinstance(prompt_template, dict):
|
||||
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
|
||||
if "template" not in prompt_template:
|
||||
raise ValueError(
|
||||
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: 32,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 129,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Union[str, List[str]] = None,
|
||||
height: int = 720,
|
||||
width: int = 1280,
|
||||
num_frames: int = 129,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: List[float] = None,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
max_sequence_length: int = 256,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [2],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
|
||||
CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
|
||||
not applied.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
|
||||
where the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_template=prompt_template,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
||||
if pooled_prompt_embeds is not None:
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_latent_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare guidance condition
|
||||
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_without_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
noise_pred_perturb = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred + self._stg_scale * (noise_pred - noise_pred_perturb)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return HunyuanVideoPipelineOutput(frames=video)
|
||||
@@ -1,886 +0,0 @@
|
||||
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_ltx import LTXSTGPipeline
|
||||
|
||||
>>> pipe = LTXSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> video = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=704,
|
||||
... height=480,
|
||||
... num_frames=161,
|
||||
... num_inference_steps=50,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
Args:
|
||||
transformer ([`LTXVideoTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLLTXVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: LTXVideoTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.transformer_spatial_patch_size = (
|
||||
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
||||
)
|
||||
self.transformer_temporal_patch_size = (
|
||||
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
||||
)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
post_patch_num_frames,
|
||||
patch_size_t,
|
||||
post_patch_height,
|
||||
patch_size,
|
||||
post_patch_width,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
width = width // self.vae_spatial_compression_ratio
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
frame_rate: int = 25,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [19],
|
||||
stg_scale: Optional[float] = 1.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `512`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, defaults to `704`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `161`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `3 `):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `128 `):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -1,985 +0,0 @@
|
||||
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
>>> from examples.community.pipeline_stg_ltx_image2video import LTXImageToVideoSTGPipeline
|
||||
|
||||
>>> pipe = LTXImageToVideoSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/11.png"
|
||||
>>> )
|
||||
>>> prompt = "A medieval fantasy scene featuring a rugged man with shoulder-length brown hair and a beard. He wears a dark leather tunic over a maroon shirt with intricate metal details. His facial expression is serious and intense, and he is making a gesture with his right hand, forming a small circle with his thumb and index finger. The warm golden lighting casts dramatic shadows on his face. The background includes an ornate stone arch and blurred medieval-style decor, creating an epic atmosphere."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> video = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=704,
|
||||
... height=480,
|
||||
... num_frames=161,
|
||||
... num_inference_steps=50,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling,
|
||||
>>> ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for image-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
Args:
|
||||
transformer ([`LTXVideoTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLLTXVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: LTXVideoTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.transformer_spatial_patch_size = (
|
||||
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
||||
)
|
||||
self.transformer_temporal_patch_size = (
|
||||
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
||||
)
|
||||
|
||||
self.default_height = 512
|
||||
self.default_width = 704
|
||||
self.default_frames = 121
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
post_patch_num_frames,
|
||||
patch_size_t,
|
||||
post_patch_height,
|
||||
patch_size,
|
||||
post_patch_width,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
width = width // self.vae_spatial_compression_ratio
|
||||
num_frames = (
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
|
||||
)
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
conditioning_mask = latents.new_zeros(shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
return latents.to(device=device, dtype=dtype), conditioning_mask
|
||||
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
|
||||
]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
|
||||
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
|
||||
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
).squeeze(-1)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
return latents, conditioning_mask
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
frame_rate: int = 25,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 128,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [19],
|
||||
stg_scale: Optional[float] = 1.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `512`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, defaults to `704`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `161`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `3 `):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `128 `):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._stg_scale = stg_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
if latents is None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
image = image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents, conditioning_mask = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
||||
rope_interpolation_scale = (
|
||||
1 / latent_frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
timestep, _ = timestep.chunk(2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
timestep, _, _ = timestep.chunk(3)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred = self._unpack_latents(
|
||||
noise_pred,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
|
||||
noise_pred = noise_pred[:, :, 1:]
|
||||
noise_latents = latents[:, :, 1:]
|
||||
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
|
||||
|
||||
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -1,843 +0,0 @@
|
||||
# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import Mochi1LoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
|
||||
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from examples.community.pipeline_stg_mochi import MochiSTGPipeline
|
||||
|
||||
>>> pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> pipe.enable_vae_tiling()
|
||||
>>> prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [34] # Layer indices from 0 to 41
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
>>> do_rescaling = False
|
||||
|
||||
>>> frames = pipe(
|
||||
... prompt=prompt,
|
||||
... num_inference_steps=28,
|
||||
... guidance_scale=3.5,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... do_rescaling=do_rescaling).frames[0]
|
||||
>>> export_to_video(frames, "mochi.mp4")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states_ptb = hidden_states[2:]
|
||||
encoder_hidden_states_ptb = encoder_hidden_states[2:]
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
|
||||
if not self.context_pre_only:
|
||||
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, temb
|
||||
)
|
||||
else:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
|
||||
attn_hidden_states, context_attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
|
||||
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
|
||||
|
||||
if not self.context_pre_only:
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
)
|
||||
norm_encoder_hidden_states = self.norm3_context(
|
||||
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
|
||||
)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
|
||||
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||
)
|
||||
|
||||
hidden_states[2:] = hidden_states_ptb
|
||||
encoder_hidden_states[2:] = encoder_hidden_states_ptb
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
||||
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||
if linear_steps is None:
|
||||
linear_steps = num_steps // 2
|
||||
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
|
||||
quadratic_steps = num_steps - linear_steps
|
||||
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
|
||||
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
|
||||
const = quadratic_coef * (linear_steps**2)
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return sigma_schedule
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom value")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
r"""
|
||||
The mochi pipeline for text-to-video generation.
|
||||
|
||||
Reference: https://github.com/genmoai/models
|
||||
|
||||
Args:
|
||||
transformer ([`MochiTransformer3DModel`]):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLMochi`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLMochi,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: MochiTransformer3DModel,
|
||||
force_zeros_for_empty_prompt: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
# TODO: determine these scaling factors from model parameters
|
||||
self.vae_spatial_scale_factor = 8
|
||||
self.vae_temporal_scale_factor = 6
|
||||
self.patch_size = 2
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
|
||||
)
|
||||
self.default_height = 480
|
||||
self.default_width = 848
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
# The original Mochi implementation zeros out empty negative prompts
|
||||
# but this can lead to overflow when placing the entire pipeline under the autocast context
|
||||
# adding this here so that we can enable zeroing prompts if necessary
|
||||
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
|
||||
text_input_ids = torch.zeros_like(text_input_ids, device=device)
|
||||
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = height // self.vae_spatial_scale_factor
|
||||
width = width // self.vae_spatial_scale_factor
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
|
||||
latents = latents.to(dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 19,
|
||||
num_inference_steps: int = 64,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [34],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
do_rescaling: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to `self.default_height`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to `self.default_width`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `19`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, defaults to `4.5`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to `256`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
|
||||
is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.default_height
|
||||
width = width or self.default_width
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for i in stg_applied_layers_idx:
|
||||
self.transformer.transformer_blocks[i].forward = types.MethodType(
|
||||
forward_with_stg, self.transformer.transformer_blocks[i]
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat(
|
||||
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
|
||||
)
|
||||
|
||||
# 5. Prepare timestep
|
||||
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
||||
threshold_noise = 0.025
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
||||
sigmas = np.array(sigmas)
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
|
||||
# to make sure we're using the correct non-reversed timestep value.
|
||||
self._current_timestep = 1000 - t
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
latent_model_input = torch.cat([latents] * 3)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# Mochi CFG + Sampling runs in FP32
|
||||
noise_pred = noise_pred.to(torch.float32)
|
||||
|
||||
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
|
||||
if do_rescaling:
|
||||
rescaling_scale = 0.7
|
||||
factor = noise_pred_text.std() / noise_pred.std()
|
||||
factor = rescaling_scale * factor + (1 - rescaling_scale)
|
||||
noise_pred = noise_pred * factor
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return MochiPipelineOutput(frames=video)
|
||||
@@ -1,32 +0,0 @@
|
||||
# AnyTextPipeline Pipeline
|
||||
|
||||
Project page: https://aigcdesigngroup.github.io/homepage_anytext
|
||||
|
||||
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
|
||||
|
||||
Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from anytext_controlnet import AnyTextControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# I chose a font file shared by an HF staff:
|
||||
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
||||
|
||||
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
|
||||
variant="fp16",)
|
||||
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
|
||||
controlnet=anytext_controlnet, torch_dtype=torch.float16,
|
||||
trust_remote_code=False, # One needs to give permission to run this pipeline's code
|
||||
).to("cuda")
|
||||
|
||||
# generate image
|
||||
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
|
||||
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
|
||||
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,463 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
|
||||
# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
|
||||
# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
|
||||
#
|
||||
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.controlnets.controlnet import (
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class AnyTextControlNetConditioningEmbedding(nn.Module):
|
||||
"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
glyph_channels=1,
|
||||
position_channels=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.glyph_block = nn.Sequential(
|
||||
nn.Conv2d(glyph_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_block = nn.Sequential(
|
||||
nn.Conv2d(position_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 64, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
|
||||
|
||||
def forward(self, glyphs, positions, text_info):
|
||||
glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
|
||||
position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
|
||||
guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
|
||||
|
||||
return guided_hint
|
||||
|
||||
|
||||
class AnyTextControlNetModel(ControlNetModel):
|
||||
"""
|
||||
A AnyTextControlNetModel model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 1,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
conditioning_channels,
|
||||
flip_sin_to_cos,
|
||||
freq_shift,
|
||||
down_block_types,
|
||||
mid_block_type,
|
||||
only_cross_attention,
|
||||
block_out_channels,
|
||||
layers_per_block,
|
||||
downsample_padding,
|
||||
mid_block_scale_factor,
|
||||
act_fn,
|
||||
norm_num_groups,
|
||||
norm_eps,
|
||||
cross_attention_dim,
|
||||
transformer_layers_per_block,
|
||||
encoder_hid_dim,
|
||||
encoder_hid_dim_type,
|
||||
attention_head_dim,
|
||||
num_attention_heads,
|
||||
use_linear_projection,
|
||||
class_embed_type,
|
||||
addition_embed_type,
|
||||
addition_time_embed_dim,
|
||||
num_class_embeds,
|
||||
upcast_attention,
|
||||
resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels,
|
||||
global_pool_conditions,
|
||||
addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
glyph_channels=conditioning_channels,
|
||||
position_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`~PromptDiffusionControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
#controlnet_cond (`torch.Tensor`):
|
||||
# The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
# elif channel_order == "bgr":
|
||||
# controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.models.controlnet.zero_module
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
@@ -1,209 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .RecSVTR import Block
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Im2Im(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Im2Seq(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# assert H == 1
|
||||
x = x.reshape(B, C, H * W)
|
||||
x = x.permute((0, 2, 1))
|
||||
return x
|
||||
|
||||
|
||||
class EncoderWithRNN(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(EncoderWithRNN, self).__init__()
|
||||
hidden_size = kwargs.get("hidden_size", 256)
|
||||
self.out_channels = hidden_size * 2
|
||||
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
|
||||
|
||||
def forward(self, x):
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
|
||||
|
||||
class SequenceEncoder(nn.Module):
|
||||
def __init__(self, in_channels, encoder_type="rnn", **kwargs):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
self.encoder_reshape = Im2Seq(in_channels)
|
||||
self.out_channels = self.encoder_reshape.out_channels
|
||||
self.encoder_type = encoder_type
|
||||
if encoder_type == "reshape":
|
||||
self.only_reshape = True
|
||||
else:
|
||||
support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
|
||||
assert encoder_type in support_encoder_dict, "{} must in {}".format(
|
||||
encoder_type, support_encoder_dict.keys()
|
||||
)
|
||||
|
||||
self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
|
||||
self.out_channels = self.encoder.out_channels
|
||||
self.only_reshape = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.encoder_type != "svtr":
|
||||
x = self.encoder_reshape(x)
|
||||
if not self.only_reshape:
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
else:
|
||||
x = self.encoder(x)
|
||||
x = self.encoder_reshape(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = Swish()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class EncoderWithSVTR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
dims=64, # XS
|
||||
depth=2,
|
||||
hidden_dims=120,
|
||||
use_guide=False,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
mlp_ratio=2.0,
|
||||
drop_rate=0.1,
|
||||
attn_drop_rate=0.1,
|
||||
drop_path=0.0,
|
||||
qk_scale=None,
|
||||
):
|
||||
super(EncoderWithSVTR, self).__init__()
|
||||
self.depth = depth
|
||||
self.use_guide = use_guide
|
||||
self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
|
||||
self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")
|
||||
|
||||
self.svtr_block = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=hidden_dims,
|
||||
num_heads=num_heads,
|
||||
mixer="Global",
|
||||
HW=None,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer="swish",
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=drop_path,
|
||||
norm_layer="nn.LayerNorm",
|
||||
epsilon=1e-05,
|
||||
prenorm=False,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
|
||||
self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
|
||||
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
|
||||
self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")
|
||||
|
||||
self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
|
||||
self.out_channels = dims
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
# weight initialization
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# for use guide
|
||||
if self.use_guide:
|
||||
z = x.clone()
|
||||
z.stop_gradient = True
|
||||
else:
|
||||
z = x
|
||||
# for short cut
|
||||
h = z
|
||||
# reduce dim
|
||||
z = self.conv1(z)
|
||||
z = self.conv2(z)
|
||||
# SVTR global block
|
||||
B, C, H, W = z.shape
|
||||
z = z.flatten(2).permute(0, 2, 1)
|
||||
|
||||
for blk in self.svtr_block:
|
||||
z = blk(z)
|
||||
|
||||
z = self.norm(z)
|
||||
# last stage
|
||||
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
|
||||
z = self.conv3(z)
|
||||
z = torch.cat((h, z), dim=1)
|
||||
z = self.conv1x1(self.conv4(z))
|
||||
|
||||
return z
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
svtrRNN = EncoderWithSVTR(56)
|
||||
print(svtrRNN)
|
||||
@@ -1,45 +0,0 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CTCHead(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
|
||||
):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
bias=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = {}
|
||||
result["ctc"] = predicts
|
||||
result["ctc_neck"] = x
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
return result
|
||||
@@ -1,49 +0,0 @@
|
||||
from torch import nn
|
||||
|
||||
from .RecCTCHead import CTCHead
|
||||
from .RecMv1_enhance import MobileNetV1Enhance
|
||||
from .RNN import Im2Im, Im2Seq, SequenceEncoder
|
||||
|
||||
|
||||
backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
|
||||
neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
|
||||
head_dict = {"CTCHead": CTCHead}
|
||||
|
||||
|
||||
class RecModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert "in_channels" in config, "in_channels must in model config"
|
||||
backbone_type = config["backbone"].pop("type")
|
||||
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
|
||||
self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
|
||||
|
||||
neck_type = config["neck"].pop("type")
|
||||
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
|
||||
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])
|
||||
|
||||
head_type = config["head"].pop("type")
|
||||
assert head_type in head_dict, f"head.type must in {head_dict}"
|
||||
self.head = head_dict[head_type](self.neck.out_channels, **config["head"])
|
||||
|
||||
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
|
||||
|
||||
def load_3rd_state_dict(self, _3rd_name, _state):
|
||||
self.backbone.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.neck.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.head.load_3rd_state_dict(_3rd_name, _state)
|
||||
|
||||
def forward(self, x):
|
||||
import torch
|
||||
|
||||
x = x.to(torch.float32)
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head.ctc_encoder(x)
|
||||
return x
|
||||
@@ -1,197 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .common import Activation
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
|
||||
):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.act = act
|
||||
self._conv = nn.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self._batch_norm = nn.BatchNorm2d(
|
||||
num_filters,
|
||||
)
|
||||
if self.act is not None:
|
||||
self._act = Activation(act_type=act, inplace=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
if self.act is not None:
|
||||
y = self._act(y)
|
||||
return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
|
||||
):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
num_groups=int(num_groups * scale),
|
||||
)
|
||||
if use_se:
|
||||
self._se = SEModule(int(num_filters1 * scale))
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
num_channels=int(num_filters1 * scale),
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._depthwise_conv(inputs)
|
||||
if self.use_se:
|
||||
y = self._se(y)
|
||||
y = self._pointwise_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Module):
|
||||
def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
|
||||
)
|
||||
|
||||
conv2_1 = DepthwiseSeparable(
|
||||
num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv2_1)
|
||||
|
||||
conv2_2 = DepthwiseSeparable(
|
||||
num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv2_2)
|
||||
|
||||
conv3_1 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv3_1)
|
||||
|
||||
conv3_2 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=(2, 1),
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv3_2)
|
||||
|
||||
conv4_1 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale
|
||||
)
|
||||
self.block_list.append(conv4_1)
|
||||
|
||||
conv4_2 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=(2, 1),
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv4_2)
|
||||
|
||||
for _ in range(5):
|
||||
conv5 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=False,
|
||||
)
|
||||
self.block_list.append(conv5)
|
||||
|
||||
conv5_6 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=(2, 1),
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=True,
|
||||
)
|
||||
self.block_list.append(conv5_6)
|
||||
|
||||
conv6 = DepthwiseSeparable(
|
||||
num_channels=int(1024 * scale),
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=last_conv_stride,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
use_se=True,
|
||||
scale=scale,
|
||||
)
|
||||
self.block_list.append(conv6)
|
||||
|
||||
self.block_list = nn.Sequential(*self.block_list)
|
||||
if last_pool_type == "avg":
|
||||
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = int(1024 * scale)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1(inputs)
|
||||
y = self.block_list(y)
|
||||
y = self.pool(y)
|
||||
return y
|
||||
|
||||
|
||||
def hardsigmoid(x):
|
||||
return F.relu6(x + 3.0, inplace=True) / 6.0
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs)
|
||||
x = torch.mul(inputs, outputs)
|
||||
|
||||
return x
|
||||
@@ -1,570 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional
|
||||
from torch.nn.init import ones_, trunc_normal_, zeros_
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0.0, training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = torch.tensor(1 - drop_prob)
|
||||
shape = (x.size()[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
|
||||
random_tensor = torch.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
if isinstance(act_layer, str):
|
||||
self.act = Swish()
|
||||
else:
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvMixer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
HW=(8, 25),
|
||||
local_k=(3, 3),
|
||||
):
|
||||
super().__init__()
|
||||
self.HW = HW
|
||||
self.dim = dim
|
||||
self.local_mixer = nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
local_k,
|
||||
1,
|
||||
(local_k[0] // 2, local_k[1] // 2),
|
||||
groups=num_heads,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.HW[0]
|
||||
w = self.HW[1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
|
||||
x = self.local_mixer(x)
|
||||
x = x.flatten(2).transpose([0, 2, 1])
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
mixer="Global",
|
||||
HW=(8, 25),
|
||||
local_k=(7, 11),
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.HW = HW
|
||||
if HW is not None:
|
||||
H = HW[0]
|
||||
W = HW[1]
|
||||
self.N = H * W
|
||||
self.C = dim
|
||||
if mixer == "Local" and HW is not None:
|
||||
hk = local_k[0]
|
||||
wk = local_k[1]
|
||||
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
|
||||
for h in range(0, H):
|
||||
for w in range(0, W):
|
||||
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
|
||||
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
|
||||
mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
|
||||
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
|
||||
self.mask = mask[None, None, :]
|
||||
# self.mask = mask.unsqueeze([0, 1])
|
||||
self.mixer = mixer
|
||||
|
||||
def forward(self, x):
|
||||
if self.HW is not None:
|
||||
N = self.N
|
||||
C = self.C
|
||||
else:
|
||||
_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
|
||||
attn = q.matmul(k.permute((0, 1, 3, 2)))
|
||||
if self.mixer == "Local":
|
||||
attn += self.mask
|
||||
attn = functional.softmax(attn, dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mixer="Global",
|
||||
local_mixer=(7, 11),
|
||||
HW=(8, 25),
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer="nn.LayerNorm",
|
||||
epsilon=1e-6,
|
||||
prenorm=True,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
if mixer == "Global" or mixer == "Local":
|
||||
self.mixer = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
mixer=mixer,
|
||||
HW=HW,
|
||||
local_k=local_mixer,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
elif mixer == "Conv":
|
||||
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
|
||||
else:
|
||||
raise TypeError("The mixer must be one of [Global, Local, Conv]")
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.prenorm = prenorm
|
||||
|
||||
def forward(self, x):
|
||||
if self.prenorm:
|
||||
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
|
||||
super().__init__()
|
||||
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
|
||||
self.img_size = img_size
|
||||
self.num_patches = num_patches
|
||||
self.embed_dim = embed_dim
|
||||
self.norm = None
|
||||
if sub_num == 2:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
)
|
||||
if sub_num == 3:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 4,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class SubSample(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
|
||||
super().__init__()
|
||||
self.types = types
|
||||
if types == "Pool":
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.proj = nn.Linear(in_channels, out_channels)
|
||||
else:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
self.norm = eval(sub_norm)(out_channels)
|
||||
if act is not None:
|
||||
self.act = act()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.types == "Pool":
|
||||
x1 = self.avgpool(x)
|
||||
x2 = self.maxpool(x)
|
||||
x = (x1 + x2) * 0.5
|
||||
out = self.proj(x.flatten(2).permute((0, 2, 1)))
|
||||
else:
|
||||
x = self.conv(x)
|
||||
out = x.flatten(2).permute((0, 2, 1))
|
||||
out = self.norm(out)
|
||||
if self.act is not None:
|
||||
out = self.act(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SVTRNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=[48, 100],
|
||||
in_channels=3,
|
||||
embed_dim=[64, 128, 256],
|
||||
depth=[3, 6, 3],
|
||||
num_heads=[2, 4, 8],
|
||||
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
|
||||
local_mixer=[[7, 11], [7, 11], [7, 11]],
|
||||
patch_merging="Conv", # Conv, Pool, None
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
last_drop=0.1,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer="nn.LayerNorm",
|
||||
sub_norm="nn.LayerNorm",
|
||||
epsilon=1e-6,
|
||||
out_channels=192,
|
||||
out_char_num=25,
|
||||
block_unit="Block",
|
||||
act="nn.GELU",
|
||||
last_stage=True,
|
||||
sub_num=2,
|
||||
prenorm=True,
|
||||
use_lenhead=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.embed_dim = embed_dim
|
||||
self.out_channels = out_channels
|
||||
self.prenorm = prenorm
|
||||
patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
|
||||
# self.pos_embed = self.create_parameter(
|
||||
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
|
||||
|
||||
# self.add_parameter("pos_embed", self.pos_embed)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
Block_unit = eval(block_unit)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, sum(depth))
|
||||
self.blocks1 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[0],
|
||||
num_heads=num_heads[0],
|
||||
mixer=mixer[0 : depth[0]][i],
|
||||
HW=self.HW,
|
||||
local_mixer=local_mixer[0],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[0 : depth[0]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[0])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample1 = SubSample(
|
||||
embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
||||
)
|
||||
HW = [self.HW[0] // 2, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.patch_merging = patch_merging
|
||||
self.blocks2 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[1],
|
||||
num_heads=num_heads[1],
|
||||
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[1],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[1])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample2 = SubSample(
|
||||
embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
||||
)
|
||||
HW = [self.HW[0] // 4, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.blocks3 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[2],
|
||||
num_heads=num_heads[2],
|
||||
mixer=mixer[depth[0] + depth[1] :][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[2],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] + depth[1] :][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
for i in range(depth[2])
|
||||
]
|
||||
)
|
||||
self.last_stage = last_stage
|
||||
if last_stage:
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
|
||||
self.last_conv = nn.Conv2d(
|
||||
in_channels=embed_dim[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = nn.Dropout(p=last_drop)
|
||||
if not prenorm:
|
||||
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
|
||||
self.use_lenhead = use_lenhead
|
||||
if use_lenhead:
|
||||
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
|
||||
self.hardswish_len = nn.Hardswish()
|
||||
self.dropout_len = nn.Dropout(p=last_drop)
|
||||
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks1:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
|
||||
for blk in self.blocks2:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
|
||||
for blk in self.blocks3:
|
||||
x = blk(x)
|
||||
if not self.prenorm:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if self.use_lenhead:
|
||||
len_x = self.len_conv(x.mean(1))
|
||||
len_x = self.dropout_len(self.hardswish_len(len_x))
|
||||
if self.last_stage:
|
||||
if self.patch_merging is not None:
|
||||
h = self.HW[0] // 4
|
||||
else:
|
||||
h = self.HW[0]
|
||||
x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
if self.use_lenhead:
|
||||
return x, len_x
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = torch.rand(1, 3, 48, 100)
|
||||
svtr = SVTRNet()
|
||||
|
||||
out = svtr(a)
|
||||
print(svtr)
|
||||
print(out.size())
|
||||
@@ -1,74 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Hswish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hswish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
|
||||
|
||||
|
||||
# out = max(0, min(1, slop*x+offset))
|
||||
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
|
||||
class Hsigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hsigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
|
||||
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(GELU, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
if self.inplace:
|
||||
x.mul_(torch.sigmoid(x))
|
||||
return x
|
||||
else:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Activation(nn.Module):
|
||||
def __init__(self, act_type, inplace=True):
|
||||
super(Activation, self).__init__()
|
||||
act_type = act_type.lower()
|
||||
if act_type == "relu":
|
||||
self.act = nn.ReLU(inplace=inplace)
|
||||
elif act_type == "relu6":
|
||||
self.act = nn.ReLU6(inplace=inplace)
|
||||
elif act_type == "sigmoid":
|
||||
raise NotImplementedError
|
||||
elif act_type == "hard_sigmoid":
|
||||
self.act = Hsigmoid(inplace)
|
||||
elif act_type == "hard_swish":
|
||||
self.act = Hswish(inplace=inplace)
|
||||
elif act_type == "leakyrelu":
|
||||
self.act = nn.LeakyReLU(inplace=inplace)
|
||||
elif act_type == "gelu":
|
||||
self.act = GELU(inplace=inplace)
|
||||
elif act_type == "swish":
|
||||
self.act = Swish(inplace=inplace)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.act(inputs)
|
||||
@@ -1,95 +0,0 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
[
|
||||
\
|
||||
]
|
||||
^
|
||||
_
|
||||
`
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
{
|
||||
|
|
||||
}
|
||||
~
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
|
||||
@@ -128,10 +128,6 @@ _deps = [
|
||||
"GitPython<3.1.19",
|
||||
"scipy",
|
||||
"onnx",
|
||||
"optimum_quanto>=0.2.6",
|
||||
"gguf>=0.10.0",
|
||||
"torchao>=0.7.0",
|
||||
"bitsandbytes>=0.43.3",
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"tensorboard",
|
||||
@@ -239,11 +235,6 @@ extras["test"] = deps_list(
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
|
||||
extras["gguf"] = deps_list("gguf", "accelerate")
|
||||
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
|
||||
extras["torchao"] = deps_list("torchao", "accelerate")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
else:
|
||||
|
||||
@@ -6,19 +6,14 @@ from .utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_optimum_quanto_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -37,7 +32,7 @@ _import_structure = {
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
"quantizers.quantization_config": [],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
"OptionalDependencyNotAvailable",
|
||||
@@ -59,54 +54,6 @@ _import_structure = {
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_bitsandbytes_objects
|
||||
|
||||
_import_structure["utils.dummy_bitsandbytes_objects"] = [
|
||||
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_gguf_objects
|
||||
|
||||
_import_structure["utils.dummy_gguf_objects"] = [
|
||||
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torchao_objects
|
||||
|
||||
_import_structure["utils.dummy_torchao_objects"] = [
|
||||
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_optimum_quanto_objects
|
||||
|
||||
_import_structure["utils.dummy_optimum_quanto_objects"] = [
|
||||
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -652,38 +599,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .configuration_utils import ConfigMixin
|
||||
|
||||
try:
|
||||
if not is_bitsandbytes_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_bitsandbytes_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig
|
||||
|
||||
try:
|
||||
if not is_gguf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_gguf_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import GGUFQuantizationConfig
|
||||
|
||||
try:
|
||||
if not is_torchao_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torchao_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_optimum_quanto_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import QuantoConfig
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
|
||||
@@ -35,10 +35,6 @@ deps = {
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"scipy": "scipy",
|
||||
"onnx": "onnx",
|
||||
"optimum_quanto": "optimum_quanto>=0.2.6",
|
||||
"gguf": "gguf>=0.10.0",
|
||||
"torchao": "torchao>=0.7.0",
|
||||
"bitsandbytes": "bitsandbytes>=0.43.3",
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
|
||||
@@ -29,16 +29,11 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
||||
|
||||
# Always use memory-efficient CPU offloading to minimize RAM usage
|
||||
|
||||
_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
||||
@@ -61,6 +56,7 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -72,8 +68,12 @@ class ModuleGroup:
|
||||
self.buffers = buffers
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
@@ -82,125 +82,23 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
# Use the most efficient module-level transfer when possible
|
||||
# This approach mirrors how PyTorch handles full model transfers
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Only onload if some parameters are not on the target device
|
||||
if any(p.device != self.onload_device for p in group_module.parameters()):
|
||||
try:
|
||||
# Try the most efficient approach using _apply
|
||||
if hasattr(group_module, "_apply"):
|
||||
# This is what module.to() uses internally
|
||||
def to_device(t):
|
||||
if t.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
return t.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
return t.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
return t
|
||||
|
||||
# Apply to all tensors without unnecessary copies
|
||||
group_module._apply(to_device)
|
||||
else:
|
||||
# Fallback to direct parameter transfer
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
except Exception as e:
|
||||
# If optimization fails, fall back to direct parameter transfer
|
||||
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle explicit parameters
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle buffers
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
if buffer.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
buffer.data = buffer.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
# For CPU offloading
|
||||
if self.offload_device.type == "cpu":
|
||||
# Synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# Empty GPU cache before offloading to reduce memory fragmentation
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# For module groups, use a single, unified approach that is closest to
|
||||
# the behavior of model.to("cpu")
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Check if we need to offload this module
|
||||
if any(p.device.type != "cpu" for p in group_module.parameters()):
|
||||
# Use PyTorch's built-in to() method directly, which preserves
|
||||
# memory mapping when moving to CPU
|
||||
try:
|
||||
# Non-blocking=False for CPU transfers, as it ensures memory is
|
||||
# immediately available and potentially preserves memory mapping
|
||||
group_module.to("cpu", non_blocking=False)
|
||||
except Exception as e:
|
||||
# If there's any error, fall back to parameter-level offloading
|
||||
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
|
||||
for param in group_module.parameters():
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle explicit parameters - move directly to CPU with non-blocking=False
|
||||
# which can preserve memory mapping in some PyTorch versions
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle buffers
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
if buffer.device.type != "cpu":
|
||||
buffer.data = buffer.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Let Python's normal reference counting handle cleanup
|
||||
# We don't force garbage collection to avoid slowing down inference
|
||||
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
else:
|
||||
# For non-CPU offloading, synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# For non-CPU offloading, use the regular approach
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
@@ -210,9 +108,6 @@ class ModuleGroup:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
# After offloading, we can unpin the memory if configured to do so
|
||||
# We'll keep it pinned by default for better performance
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
@@ -234,7 +129,6 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
# Offload to CPU
|
||||
self.group.offload_()
|
||||
return module
|
||||
|
||||
@@ -419,8 +313,7 @@ def apply_group_offloading(
|
||||
If True, offloading and onloading is done with non-blocking data transfer.
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
|
||||
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
|
||||
overlapping computation and data transfer.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -451,19 +344,12 @@ def apply_group_offloading(
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
# We no longer need a pinned group manager as we're not using pinned memory
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module,
|
||||
num_blocks_per_group,
|
||||
offload_device,
|
||||
onload_device,
|
||||
non_blocking,
|
||||
stream,
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
@@ -498,7 +384,12 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# We no longer need a CPU parameter dictionary
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -520,6 +411,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -556,6 +448,7 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -590,7 +483,12 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# We no longer need a CPU parameter dictionary
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -605,6 +503,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -649,6 +548,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -667,6 +567,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -70,7 +70,6 @@ if is_torch_available():
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"CogView4LoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
@@ -104,7 +103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
|
||||
@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
# Safe prefix to check with.
|
||||
if any(text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.info(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
|
||||
@@ -298,15 +298,19 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
keys = list(state_dict.keys())
|
||||
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
|
||||
if not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_text_encoder(
|
||||
@@ -555,26 +559,31 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix=f"{self.text_encoder_name}_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
@@ -729,15 +738,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
keys = list(state_dict.keys())
|
||||
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
|
||||
if not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -830,11 +843,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
||||
)
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
|
||||
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
@@ -1072,33 +1085,43 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix=f"{self.text_encoder_name}_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
|
||||
if len(transformer_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name)
|
||||
if not hasattr(self, "transformer")
|
||||
else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
@@ -1187,11 +1210,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
@@ -1240,6 +1262,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
if text_encoder_2_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
@@ -1249,7 +1272,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
||||
@@ -1293,7 +1315,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -1307,7 +1328,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
@@ -1518,23 +1539,18 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
transformer_lora_state_dict = {
|
||||
k: state_dict.get(k)
|
||||
for k in list(state_dict.keys())
|
||||
if k.startswith(f"{self.transformer_name}.") and "lora" in k
|
||||
k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
|
||||
}
|
||||
transformer_norm_state_dict = {
|
||||
k: state_dict.pop(k)
|
||||
for k in list(state_dict.keys())
|
||||
if k.startswith(f"{self.transformer_name}.")
|
||||
and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
|
||||
if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
|
||||
}
|
||||
|
||||
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
has_param_with_expanded_shape = False
|
||||
if len(transformer_lora_state_dict) > 0:
|
||||
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
|
||||
transformer, transformer_lora_state_dict, transformer_norm_state_dict
|
||||
)
|
||||
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
|
||||
transformer, transformer_lora_state_dict, transformer_norm_state_dict
|
||||
)
|
||||
|
||||
if has_param_with_expanded_shape:
|
||||
logger.info(
|
||||
@@ -1542,22 +1558,20 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
|
||||
"To get a comprehensive list of parameter names that were modified, enable debug logging."
|
||||
)
|
||||
if len(transformer_lora_state_dict) > 0:
|
||||
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
|
||||
transformer=transformer, lora_state_dict=transformer_lora_state_dict
|
||||
)
|
||||
for k in transformer_lora_state_dict:
|
||||
state_dict.update({k: transformer_lora_state_dict[k]})
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
|
||||
transformer=transformer, lora_state_dict=transformer_lora_state_dict
|
||||
)
|
||||
|
||||
if len(transformer_lora_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
transformer_lora_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
if len(transformer_norm_state_dict) > 0:
|
||||
transformer._transformer_norm_layers = self._load_norm_into_transformer(
|
||||
transformer_norm_state_dict,
|
||||
@@ -1565,16 +1579,18 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
discard_original_layers=False,
|
||||
)
|
||||
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
@@ -1607,14 +1623,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
keys = list(state_dict.keys())
|
||||
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
|
||||
if transformer_present:
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _load_norm_into_transformer(
|
||||
@@ -2153,14 +2172,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
keys = list(state_dict.keys())
|
||||
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
|
||||
if transformer_present:
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
|
||||
@@ -2811,7 +2833,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -2855,7 +2876,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -3116,7 +3136,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3160,7 +3179,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -3421,7 +3439,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3465,7 +3482,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -3729,7 +3745,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3773,7 +3788,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
@@ -4406,311 +4420,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`CogView4Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -54,7 +54,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -193,6 +192,11 @@ class PeftAdapterMixin:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
try:
|
||||
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
|
||||
except ImportError:
|
||||
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -236,7 +240,10 @@ class PeftAdapterMixin:
|
||||
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
||||
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
keys = list(state_dict.keys())
|
||||
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
|
||||
if len(model_keys) > 0:
|
||||
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
@@ -254,16 +261,22 @@ class PeftAdapterMixin:
|
||||
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
|
||||
# Bias layers in LoRA only have a single dimension
|
||||
if "lora_B" in key and val.ndim > 1:
|
||||
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
|
||||
rank[key] = val.shape[1]
|
||||
# Support to handle cases where layer patterns are treated as full layer names
|
||||
# was added later in PEFT. So, we handle it accordingly.
|
||||
# TODO: when we fix the minimal PEFT version for Diffusers,
|
||||
# we should remove `_maybe_adjust_config()`.
|
||||
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
|
||||
else:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
@@ -353,11 +366,6 @@ class PeftAdapterMixin:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.info(
|
||||
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
def save_lora_adapter(
|
||||
self,
|
||||
save_directory,
|
||||
|
||||
@@ -37,7 +37,6 @@ from .single_file_utils import (
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
@@ -120,10 +119,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
|
||||
@@ -117,12 +117,6 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||
"sana": [
|
||||
"blocks.0.cross_attn.q_linear.weight",
|
||||
"blocks.0.cross_attn.q_linear.bias",
|
||||
"blocks.0.cross_attn.kv_linear.weight",
|
||||
"blocks.0.cross_attn.kv_linear.bias",
|
||||
],
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
}
|
||||
@@ -184,7 +178,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
|
||||
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
|
||||
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
|
||||
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
@@ -676,9 +669,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
|
||||
model_type = "lumina2"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
|
||||
model_type = "sana"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
|
||||
if "model.diffusion_model.patch_embedding.weight" in checkpoint:
|
||||
target_key = "model.diffusion_model.patch_embedding.weight"
|
||||
@@ -2907,111 +2897,6 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
|
||||
|
||||
# Positional and patch embeddings.
|
||||
checkpoint.pop("pos_embed")
|
||||
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Timestep embeddings.
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
|
||||
|
||||
# Caption Projection.
|
||||
checkpoint.pop("y_embedder.y_embedding")
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
|
||||
converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
|
||||
|
||||
for i in range(num_layers):
|
||||
converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
|
||||
f"blocks.{i}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Self-Attention
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
|
||||
|
||||
# Output Projections
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Cross-Attention
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.q_linear.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.q_linear.bias"
|
||||
)
|
||||
|
||||
linear_sample_k, linear_sample_v = torch.chunk(
|
||||
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
|
||||
)
|
||||
linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
|
||||
|
||||
# Output Projections
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# MLP
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
|
||||
f"blocks.{i}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
|
||||
# Final layer
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
|
||||
@@ -716,6 +716,11 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Store normalization parameters as tensors
|
||||
self.mean = torch.tensor(latents_mean)
|
||||
self.std = torch.tensor(latents_std)
|
||||
self.scale = torch.stack([self.mean, 1.0 / self.std]) # Shape: [2, C]
|
||||
|
||||
self.z_dim = z_dim
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
@@ -747,6 +752,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
scale = self.scale.type_as(x)
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
@@ -765,6 +771,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
enc = self.quant_conv(out)
|
||||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
enc = torch.cat([mu, logvar], dim=1)
|
||||
self.clear_cache()
|
||||
return enc
|
||||
@@ -791,8 +799,10 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.post_quant_conv(z)
|
||||
@@ -826,7 +836,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
decoded = self._decode(z).sample
|
||||
scale = self.scale.type_as(z)
|
||||
decoded = self._decode(z, scale).sample
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
|
||||
@@ -245,9 +245,6 @@ def load_model_dict_into_meta(
|
||||
):
|
||||
param = param.to(torch.float32)
|
||||
set_module_kwargs["dtype"] = torch.float32
|
||||
# For quantizers have save weights using torch.float8_e4m3fn
|
||||
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
|
||||
pass
|
||||
else:
|
||||
param = param.to(dtype)
|
||||
set_module_kwargs["dtype"] = dtype
|
||||
@@ -295,9 +292,7 @@ def load_model_dict_into_meta(
|
||||
elif is_quantized and (
|
||||
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
|
||||
):
|
||||
hf_quantizer.create_quantized_param(
|
||||
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
|
||||
)
|
||||
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
@@ -195,7 +195,7 @@ class SanaTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
|
||||
|
||||
|
||||
@@ -12,21 +12,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous
|
||||
from ...utils import logging
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -289,7 +288,7 @@ class CogView4RotaryPosEmbed(nn.Module):
|
||||
return (freqs.cos(), freqs.sin())
|
||||
|
||||
|
||||
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
@@ -384,24 +383,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
original_size: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
@@ -436,10 +419,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -22,7 +22,6 @@ from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import CogView4LoraLoaderMixin
|
||||
from ...models import AutoencoderKL, CogView4Transformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -134,7 +133,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
class CogView4Pipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using CogView4.
|
||||
|
||||
@@ -393,10 +392,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -418,7 +413,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
@@ -532,7 +526,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# Default call parameters
|
||||
@@ -622,7 +615,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -635,7 +627,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
passed_components = passed_components or []
|
||||
if folder_names:
|
||||
if folder_names is not None:
|
||||
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# extract all components of the pipeline and their associated files
|
||||
@@ -141,25 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
return True
|
||||
|
||||
|
||||
def filter_model_files(filenames):
|
||||
"""Filter model repo files for just files/folders that contain model weights"""
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
@@ -187,10 +169,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
|
||||
legacy_variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
@@ -199,68 +177,54 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
|
||||
if not ignore_patterns:
|
||||
return filenames
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
variant_filenames = variant_weights | variant_indexes
|
||||
else:
|
||||
variant_filenames = set()
|
||||
|
||||
# ignore patterns uses glob style patterns e.g *.safetensors but we're only
|
||||
# interested in the extension name
|
||||
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
|
||||
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_filenames = non_variant_weights | non_variant_indexes
|
||||
|
||||
def filter_with_regex(filenames, pattern_re):
|
||||
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
|
||||
# all variant filenames will be used by default
|
||||
usable_filenames = set(variant_filenames)
|
||||
|
||||
# Group files by component
|
||||
components = {}
|
||||
for filename in filenames:
|
||||
def convert_to_variant(filename):
|
||||
if "index" in filename:
|
||||
variant_filename = filename.replace("index", f"index.{variant}")
|
||||
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
||||
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
||||
else:
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
def find_component(filename):
|
||||
if not len(filename.split("/")) == 2:
|
||||
components.setdefault("", []).append(filename)
|
||||
return
|
||||
component = filename.split("/")[0]
|
||||
return component
|
||||
|
||||
def has_sharded_variant(component, variant, variant_filenames):
|
||||
# If component exists check for sharded variant index filename
|
||||
# If component doesn't exist check main dir for sharded variant index filename
|
||||
component = component + "/" if component else ""
|
||||
variant_index_re = re.compile(
|
||||
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
|
||||
|
||||
for filename in non_variant_filenames:
|
||||
if convert_to_variant(filename) in variant_filenames:
|
||||
continue
|
||||
|
||||
component, _ = filename.split("/")
|
||||
components.setdefault(component, []).append(filename)
|
||||
component = find_component(filename)
|
||||
# If a sharded variant exists skip adding to allowed patterns
|
||||
if has_sharded_variant(component, variant, variant_filenames):
|
||||
continue
|
||||
|
||||
usable_filenames = set()
|
||||
variant_filenames = set()
|
||||
for component, component_filenames in components.items():
|
||||
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
|
||||
|
||||
component_variants = set()
|
||||
component_legacy_variants = set()
|
||||
component_non_variants = set()
|
||||
if variant is not None:
|
||||
component_variants = filter_with_regex(component_filenames, variant_file_re)
|
||||
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
|
||||
|
||||
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
|
||||
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
|
||||
|
||||
if component_variants or component_legacy_variants:
|
||||
variant_filenames.update(
|
||||
component_variants | component_variant_index_files
|
||||
if component_variants
|
||||
else component_legacy_variants | component_legacy_variant_index_files
|
||||
)
|
||||
|
||||
else:
|
||||
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
|
||||
component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
|
||||
|
||||
usable_filenames.update(component_non_variants | component_variant_index_files)
|
||||
|
||||
usable_filenames.update(variant_filenames)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
|
||||
raise ValueError(error_message)
|
||||
|
||||
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
usable_filenames.add(filename)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
@@ -958,6 +922,10 @@ def _get_custom_components_and_folders(
|
||||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
return custom_components, folder_names
|
||||
|
||||
|
||||
@@ -965,6 +933,7 @@ def _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names: List[str],
|
||||
model_filenames: List[str],
|
||||
variant_filenames: List[str],
|
||||
use_safetensors: bool,
|
||||
from_flax: bool,
|
||||
allow_pickle: bool,
|
||||
@@ -995,6 +964,16 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
@@ -1002,6 +981,16 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
|
||||
f"your folder structure."
|
||||
)
|
||||
|
||||
return ignore_patterns
|
||||
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ from .pipeline_loading_utils import (
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
_unwrap_model,
|
||||
_update_init_kwargs_with_connected_pipeline,
|
||||
filter_model_files,
|
||||
load_sub_model,
|
||||
maybe_raise_or_warn,
|
||||
variant_compatible_siblings,
|
||||
@@ -1388,8 +1387,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
|
||||
use_safetensors = use_safetensors if use_safetensors is not None else True
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
allow_patterns = None
|
||||
ignore_patterns = None
|
||||
@@ -1404,18 +1405,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
if not local_files_only:
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
)
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
||||
warn_msg = (
|
||||
@@ -1430,20 +1419,61 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
logger.warning(warn_msg)
|
||||
|
||||
filenames = set(filenames) - set(ignore_filenames)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.22.0"):
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
||||
|
||||
custom_components, folder_names = _get_custom_components_and_folders(
|
||||
pretrained_model_name, config_dict, filenames, variant
|
||||
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
||||
)
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
|
||||
custom_pipeline = config_dict["_class_name"][0]
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
# also allow downloading generation_config.json of the transformers model
|
||||
allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
|
||||
load_components_from_hub = len(custom_components) > 0
|
||||
|
||||
@@ -1476,15 +1506,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
||||
passed_components = [k for k in expected_components if k in kwargs]
|
||||
|
||||
# retrieve the names of the folders containing model weights
|
||||
model_folder_names = {
|
||||
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
|
||||
}
|
||||
# retrieve all patterns that should not be downloaded and error out when needed
|
||||
ignore_patterns = _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names,
|
||||
filenames,
|
||||
model_filenames,
|
||||
variant_filenames,
|
||||
use_safetensors,
|
||||
from_flax,
|
||||
allow_pickle,
|
||||
@@ -1493,29 +1520,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant,
|
||||
)
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
# Don't download any objects that are passed
|
||||
allow_patterns = [
|
||||
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
|
||||
|
||||
@@ -563,15 +563,6 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
|
||||
@@ -392,17 +392,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
@@ -665,15 +654,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
|
||||
@@ -26,10 +26,8 @@ from .quantization_config import (
|
||||
GGUFQuantizationConfig,
|
||||
QuantizationConfigMixin,
|
||||
QuantizationMethod,
|
||||
QuantoConfig,
|
||||
TorchAoConfig,
|
||||
)
|
||||
from .quanto import QuantoQuantizer
|
||||
from .torchao import TorchAoHfQuantizer
|
||||
|
||||
|
||||
@@ -37,7 +35,6 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
|
||||
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
|
||||
"gguf": GGUFQuantizer,
|
||||
"quanto": QuantoQuantizer,
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
}
|
||||
|
||||
@@ -45,7 +42,6 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"bitsandbytes_4bit": BitsAndBytesConfig,
|
||||
"bitsandbytes_8bit": BitsAndBytesConfig,
|
||||
"gguf": GGUFQuantizationConfig,
|
||||
"quanto": QuantoConfig,
|
||||
"torchao": TorchAoConfig,
|
||||
}
|
||||
|
||||
|
||||
@@ -135,7 +135,6 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
target_device: "torch.device",
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
import bitsandbytes as bnb
|
||||
|
||||
@@ -446,7 +445,6 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
target_device: "torch.device",
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
@@ -108,7 +108,6 @@ class GGUFQuantizer(DiffusersQuantizer):
|
||||
target_device: "torch.device",
|
||||
state_dict: Optional[Dict[str, Any]] = None,
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
||||
|
||||
@@ -45,7 +45,6 @@ class QuantizationMethod(str, Enum):
|
||||
BITS_AND_BYTES = "bitsandbytes"
|
||||
GGUF = "gguf"
|
||||
TORCHAO = "torchao"
|
||||
QUANTO = "quanto"
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
@@ -687,38 +686,3 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
return (
|
||||
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantoConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
This is a wrapper class about all possible attributes and features that you can play with a model that has been
|
||||
loaded using `quanto`.
|
||||
|
||||
Args:
|
||||
weights_dtype (`str`, *optional*, defaults to `"int8"`):
|
||||
The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
|
||||
modules_to_not_convert (`list`, *optional*, default to `None`):
|
||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
|
||||
modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights_dtype: str = "int8",
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.quant_method = QuantizationMethod.QUANTO
|
||||
self.weights_dtype = weights_dtype
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct
|
||||
"""
|
||||
accepted_weights = ["float8", "int8", "int4", "int2"]
|
||||
if self.weights_dtype not in accepted_weights:
|
||||
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .quanto_quantizer import QuantoQuantizer
|
||||
@@ -1,177 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from diffusers.utils.import_utils import is_optimum_quanto_version
|
||||
|
||||
from ...utils import (
|
||||
get_module_from_name,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_optimum_quanto_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.utils import CustomDtype, set_module_tensor_to_device
|
||||
|
||||
if is_optimum_quanto_available():
|
||||
from .utils import _replace_with_quanto_layers
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QuantoQuantizer(DiffusersQuantizer):
|
||||
r"""
|
||||
Diffusers Quantizer for Optimum Quanto
|
||||
"""
|
||||
|
||||
use_keep_in_fp32_modules = True
|
||||
requires_calibration = False
|
||||
required_packages = ["quanto", "accelerate"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not is_optimum_quanto_available():
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
|
||||
)
|
||||
if not is_optimum_quanto_version(">=", "0.2.6"):
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
|
||||
"Please upgrade your installation with `pip install --upgrade optimum-quanto"
|
||||
)
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
|
||||
)
|
||||
|
||||
device_map = kwargs.get("device_map", None)
|
||||
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
|
||||
raise ValueError(
|
||||
"`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
|
||||
)
|
||||
|
||||
def check_if_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
# Quanto imports diffusers internally. This is here to prevent circular imports
|
||||
from optimum.quanto import QModuleMixin, QTensor
|
||||
from optimum.quanto.tensor.packed import PackedTensor
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
|
||||
return True
|
||||
elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
|
||||
return not module.frozen
|
||||
|
||||
return False
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create the quantized parameter by calling .freeze() after setting it to the module.
|
||||
"""
|
||||
|
||||
dtype = kwargs.get("dtype", torch.float32)
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if self.pre_quantized:
|
||||
setattr(module, tensor_name, param_value)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
|
||||
module.freeze()
|
||||
module.weight.requires_grad = False
|
||||
|
||||
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
|
||||
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
||||
return max_memory
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if is_accelerate_version(">=", "0.27.0"):
|
||||
mapping = {
|
||||
"int8": torch.int8,
|
||||
"float8": CustomDtype.FP8,
|
||||
"int4": CustomDtype.INT4,
|
||||
"int2": CustomDtype.INT2,
|
||||
}
|
||||
target_dtype = mapping[self.quantization_config.weights_dtype]
|
||||
|
||||
return target_dtype
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
|
||||
torch_dtype = torch.float32
|
||||
return torch_dtype
|
||||
|
||||
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
|
||||
# Quanto imports diffusers internally. This is here to prevent circular imports
|
||||
from optimum.quanto import QModuleMixin
|
||||
|
||||
not_missing_keys = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, QModuleMixin):
|
||||
for missing in missing_keys:
|
||||
if (
|
||||
(name in missing or name in f"{prefix}.{missing}")
|
||||
and not missing.endswith(".weight")
|
||||
and not missing.endswith(".bias")
|
||||
):
|
||||
not_missing_keys.append(missing)
|
||||
return [k for k in missing_keys if k not in not_missing_keys]
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
device_map,
|
||||
keep_in_fp32_modules: List[str] = [],
|
||||
**kwargs,
|
||||
):
|
||||
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
|
||||
|
||||
if not isinstance(self.modules_to_not_convert, list):
|
||||
self.modules_to_not_convert = [self.modules_to_not_convert]
|
||||
|
||||
self.modules_to_not_convert.extend(keep_in_fp32_modules)
|
||||
|
||||
model = _replace_with_quanto_layers(
|
||||
model,
|
||||
modules_to_not_convert=self.modules_to_not_convert,
|
||||
quantization_config=self.quantization_config,
|
||||
pre_quantized=self.pre_quantized,
|
||||
)
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||
return model
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
return True
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import is_accelerate_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
|
||||
# Quanto imports diffusers internally. These are placed here to avoid circular imports
|
||||
from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8
|
||||
|
||||
def _get_weight_type(dtype: str):
|
||||
return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]
|
||||
|
||||
def _replace_layers(model, quantization_config, modules_to_not_convert):
|
||||
has_children = list(model.children())
|
||||
if not has_children:
|
||||
return model
|
||||
|
||||
for name, module in model.named_children():
|
||||
_replace_layers(module, quantization_config, modules_to_not_convert)
|
||||
|
||||
if name in modules_to_not_convert:
|
||||
continue
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
with init_empty_weights():
|
||||
qlinear = QLinear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
dtype=module.weight.dtype,
|
||||
weights=_get_weight_type(quantization_config.weights_dtype),
|
||||
)
|
||||
model._modules[name] = qlinear
|
||||
model._modules[name].source_cls = type(module)
|
||||
model._modules[name].requires_grad_(False)
|
||||
|
||||
return model
|
||||
|
||||
model = _replace_layers(model, quantization_config, modules_to_not_convert)
|
||||
has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
|
||||
" Please check your model architecture, or submit an issue on Github if you think this is a bug."
|
||||
" https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
# We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
|
||||
# to match when trying to load weights with load_model_dict_into_meta
|
||||
if pre_quantized:
|
||||
freeze(model)
|
||||
|
||||
return model
|
||||
@@ -23,14 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...utils import (
|
||||
get_module_from_name,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_torchao_available,
|
||||
is_torchao_version,
|
||||
logging,
|
||||
)
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
@@ -69,43 +62,6 @@ if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
|
||||
def _update_torch_safe_globals():
|
||||
safe_globals = [
|
||||
(torch.uint1, "torch.uint1"),
|
||||
(torch.uint2, "torch.uint2"),
|
||||
(torch.uint3, "torch.uint3"),
|
||||
(torch.uint4, "torch.uint4"),
|
||||
(torch.uint5, "torch.uint5"),
|
||||
(torch.uint6, "torch.uint6"),
|
||||
(torch.uint7, "torch.uint7"),
|
||||
]
|
||||
try:
|
||||
from torchao.dtypes import NF4Tensor
|
||||
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
|
||||
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
|
||||
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
|
||||
|
||||
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
|
||||
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
|
||||
)
|
||||
logger.debug(e)
|
||||
|
||||
finally:
|
||||
torch.serialization.add_safe_globals(safe_globals=safe_globals)
|
||||
|
||||
|
||||
if (
|
||||
is_torch_available()
|
||||
and is_torch_version(">=", "2.6.0")
|
||||
and is_torchao_available()
|
||||
and is_torchao_version(">=", "0.7.0")
|
||||
):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -259,7 +215,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
target_device: "torch.device",
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: List[str],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
|
||||
|
||||
@@ -79,8 +79,6 @@ from .import_utils import (
|
||||
is_matplotlib_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_optimum_quanto_available,
|
||||
is_optimum_quanto_version,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_safetensors_available,
|
||||
@@ -94,7 +92,6 @@ from .import_utils import (
|
||||
is_torch_xla_available,
|
||||
is_torch_xla_version,
|
||||
is_torchao_available,
|
||||
is_torchao_version,
|
||||
is_torchsde_available,
|
||||
is_torchvision_available,
|
||||
is_transformers_available,
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class BitsAndBytesConfig(metaclass=DummyObject):
|
||||
_backends = ["bitsandbytes"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["bitsandbytes"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["bitsandbytes"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["bitsandbytes"])
|
||||
@@ -1,17 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class GGUFQuantizationConfig(metaclass=DummyObject):
|
||||
_backends = ["gguf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["gguf"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["gguf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["gguf"])
|
||||
@@ -1,17 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class QuantoConfig(metaclass=DummyObject):
|
||||
_backends = ["optimum_quanto"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["optimum_quanto"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["optimum_quanto"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["optimum_quanto"])
|
||||
@@ -1,17 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class TorchAoConfig(metaclass=DummyObject):
|
||||
_backends = ["torchao"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchao"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torchao"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torchao"])
|
||||
@@ -365,15 +365,6 @@ if _is_torchao_available:
|
||||
_is_torchao_available = False
|
||||
|
||||
|
||||
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
|
||||
if _is_optimum_quanto_available:
|
||||
try:
|
||||
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
|
||||
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_optimum_quanto_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
@@ -502,10 +493,6 @@ def is_torchao_available():
|
||||
return _is_torchao_available
|
||||
|
||||
|
||||
def is_optimum_quanto_available():
|
||||
return _is_optimum_quanto_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
@@ -649,11 +636,6 @@ TORCHAO_IMPORT_ERROR = """
|
||||
torchao`
|
||||
"""
|
||||
|
||||
QUANTO_IMPORT_ERROR = """
|
||||
{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
|
||||
install optimum-quanto`
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
@@ -681,7 +663,6 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
|
||||
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
|
||||
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
|
||||
("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -868,21 +849,6 @@ def is_gguf_version(operation: str, version: str):
|
||||
return compare_versions(parse(_gguf_version), operation, version)
|
||||
|
||||
|
||||
def is_torchao_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current torchao version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _is_torchao_available:
|
||||
return False
|
||||
return compare_versions(parse(_torchao_version), operation, version)
|
||||
|
||||
|
||||
def is_k_diffusion_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current k-diffusion version to a given reference with an operation.
|
||||
@@ -898,21 +864,6 @@ def is_k_diffusion_version(operation: str, version: str):
|
||||
return compare_versions(parse(_k_diffusion_version), operation, version)
|
||||
|
||||
|
||||
def is_optimum_quanto_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Accelerate version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _is_optimum_quanto_available:
|
||||
return False
|
||||
return compare_versions(parse(_optimum_quanto_version), operation, version)
|
||||
|
||||
|
||||
def get_objects_from_module(module):
|
||||
"""
|
||||
Returns a dict of object names and values in a module, while skipping private/internal objects
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
class TokenizerWrapper:
|
||||
@staticmethod
|
||||
def from_pretrained(*args, **kwargs):
|
||||
return AutoTokenizer.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True
|
||||
)
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogView4Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 4,
|
||||
"num_attention_heads": 4,
|
||||
"out_channels": 4,
|
||||
"text_embed_dim": 32,
|
||||
"time_embed_dim": 8,
|
||||
"condition_dim": 4,
|
||||
}
|
||||
transformer_cls = CogView4Transformer2DModel
|
||||
vae_kwargs = {
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
"sample_size": 128,
|
||||
}
|
||||
vae_cls = AutoencoderKL
|
||||
tokenizer_cls, tokenizer_id, tokenizer_subfolder = (
|
||||
TokenizerWrapper,
|
||||
"hf-internal-testing/tiny-random-cogview4",
|
||||
"tokenizer",
|
||||
)
|
||||
text_encoder_cls, text_encoder_id, text_encoder_subfolder = (
|
||||
GlmModel,
|
||||
"hf-internal-testing/tiny-random-cogview4",
|
||||
"text_encoder",
|
||||
)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 32, 32, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 16
|
||||
num_channels = 4
|
||||
sizes = (4, 4)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "",
|
||||
"num_inference_steps": 1,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": sequence_length,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, _, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
@@ -371,8 +371,9 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
"The provided state dict contains normalization layers in addition to LoRA layers"
|
||||
in cap_logger.out
|
||||
cap_logger.out.startswith(
|
||||
"The provided state dict contains normalization layers in addition to LoRA layers"
|
||||
)
|
||||
)
|
||||
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
|
||||
|
||||
@@ -391,7 +392,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe.load_lora_weights(norm_state_dict)
|
||||
|
||||
self.assertTrue(
|
||||
"Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
|
||||
cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
|
||||
)
|
||||
|
||||
def test_lora_parameter_expanded_shapes(self):
|
||||
|
||||
@@ -1948,50 +1948,6 @@ class PeftLoraLoaderMixinTests:
|
||||
_, _, inputs = self.get_dummy_inputs()
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
def test_logs_info_when_no_lora_keys_found(self):
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
# Skip text encoder check for now as that is handled with `transformers`.
|
||||
components, _, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
|
||||
logger = logging.get_logger("diffusers.loaders.peft")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(no_op_state_dict)
|
||||
out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
|
||||
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
|
||||
self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
|
||||
|
||||
# test only for text encoder
|
||||
for lora_module in self.pipeline_class._lora_loadable_modules:
|
||||
if "text_encoder" in lora_module:
|
||||
text_encoder = getattr(pipe, lora_module)
|
||||
if lora_module == "text_encoder":
|
||||
prefix = "text_encoder"
|
||||
elif lora_module == "text_encoder_2":
|
||||
prefix = "text_encoder_2"
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.pipeline_class.load_lora_into_text_encoder(
|
||||
no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
|
||||
)
|
||||
|
||||
def test_set_adapters_match_attention_kwargs(self):
|
||||
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
|
||||
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
|
||||
|
||||
@@ -212,7 +212,6 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
|
||||
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
def test_only_non_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -223,13 +222,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_only_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -240,13 +236,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
non_variant_file = "text_encoder/model.safetensors"
|
||||
filenames = [
|
||||
@@ -256,27 +249,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_non_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"model.{variant}.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
@@ -286,76 +275,23 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
non_variant_file = "model.safetensors"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
"model.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
"diffusion_pytorch_model.safetensors.index.json",
|
||||
"diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_non_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
@@ -366,13 +302,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
||||
assert all(variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
||||
@@ -383,49 +316,10 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
assert model_filenames == variant_filenames
|
||||
|
||||
def test_single_variant_with_sharded_non_variant_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_sharded_mixed_variants_downloaded(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "unet"
|
||||
filenames = [
|
||||
@@ -441,144 +335,9 @@ class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_downloading_when_no_variant_exists(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
|
||||
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
def test_downloading_use_safetensors_false(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
filenames = [
|
||||
"text_encoder/model.bin",
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
assert all(".safetensors" not in f for f in model_filenames)
|
||||
|
||||
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
allowed_non_variant = "diffusion_pytorch_model.safetensors"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_download_variants_when_component_has_no_safetensors_variant(self):
|
||||
ignore_patterns = None
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"unet/diffusion_pytorch_model.{variant}.bin",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert {
|
||||
f"unet/diffusion_pytorch_model.{variant}.bin",
|
||||
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
||||
} == model_filenames
|
||||
|
||||
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
|
||||
ignore_patterns = ["*.bin"]
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
|
||||
]
|
||||
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
allowed_non_variant = "unet"
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors.index.json",
|
||||
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
||||
|
||||
def test_download_sharded_legacy_variants(self):
|
||||
ignore_patterns = None
|
||||
variant = "fp16"
|
||||
filenames = [
|
||||
f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
|
||||
"vae/diffusion_pytorch_model.safetensors.index.json",
|
||||
f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
|
||||
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
||||
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
||||
f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=variant, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert all(variant in f for f in model_filenames)
|
||||
|
||||
def test_download_onnx_models(self):
|
||||
ignore_patterns = ["*.safetensors"]
|
||||
filenames = [
|
||||
"vae/model.onnx",
|
||||
"unet/model.onnx",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert model_filenames == set(filenames)
|
||||
|
||||
def test_download_flax_models(self):
|
||||
ignore_patterns = ["*.safetensors", "*.bin"]
|
||||
filenames = [
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
]
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(
|
||||
filenames, variant=None, ignore_patterns=ignore_patterns
|
||||
)
|
||||
assert model_filenames == set(filenames)
|
||||
|
||||
|
||||
class ProgressBarTests(unittest.TestCase):
|
||||
def get_dummy_components_image_generation(self):
|
||||
|
||||
@@ -54,8 +54,29 @@ if is_transformers_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import LoRALayer, get_memory_consumption_stat
|
||||
class LoRALayer(nn.Module):
|
||||
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
|
||||
|
||||
Taken from
|
||||
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module, rank: int):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.adapter = nn.Sequential(
|
||||
nn.Linear(module.in_features, rank, bias=False),
|
||||
nn.Linear(rank, module.out_features, bias=False),
|
||||
)
|
||||
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
|
||||
nn.init.normal_(self.adapter[0].weight, std=small_std)
|
||||
nn.init.zeros_(self.adapter[1].weight)
|
||||
self.adapter.to(module.weight.device)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
@@ -75,8 +96,6 @@ class Base4bitTests(unittest.TestCase):
|
||||
# This was obtained on audace so the number might slightly change
|
||||
expected_rel_difference = 3.69
|
||||
|
||||
expected_memory_saving_ratio = 0.8
|
||||
|
||||
prompt = "a beautiful sunset amidst the mountains."
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
@@ -121,10 +140,8 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(self, "model_fp16"):
|
||||
del self.model_fp16
|
||||
if hasattr(self, "model_4bit"):
|
||||
del self.model_4bit
|
||||
del self.model_fp16
|
||||
del self.model_4bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
@@ -163,32 +180,6 @@ class BnB4BitBasicTests(Base4bitTests):
|
||||
linear = get_some_linear_layer(self.model_4bit)
|
||||
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
|
||||
|
||||
def test_model_memory_usage(self):
|
||||
# Delete to not let anything interfere.
|
||||
del self.model_4bit, self.model_fp16
|
||||
|
||||
# Re-instantiate.
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = {
|
||||
k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
|
||||
}
|
||||
model_fp16 = SD3Transformer2DModel.from_pretrained(
|
||||
self.model_name, subfolder="transformer", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
|
||||
del model_fp16
|
||||
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
)
|
||||
model_4bit = SD3Transformer2DModel.from_pretrained(
|
||||
self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
|
||||
)
|
||||
quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
|
||||
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
|
||||
|
||||
def test_original_dtype(self):
|
||||
r"""
|
||||
A simple test to check if the model succesfully stores the original dtype
|
||||
|
||||
@@ -60,8 +60,29 @@ if is_transformers_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import LoRALayer, get_memory_consumption_stat
|
||||
class LoRALayer(nn.Module):
|
||||
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
|
||||
|
||||
Taken from
|
||||
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module, rank: int):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.adapter = nn.Sequential(
|
||||
nn.Linear(module.in_features, rank, bias=False),
|
||||
nn.Linear(rank, module.out_features, bias=False),
|
||||
)
|
||||
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
|
||||
nn.init.normal_(self.adapter[0].weight, std=small_std)
|
||||
nn.init.zeros_(self.adapter[1].weight)
|
||||
self.adapter.to(module.weight.device)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
@@ -81,8 +102,6 @@ class Base8bitTests(unittest.TestCase):
|
||||
# This was obtained on audace so the number might slightly change
|
||||
expected_rel_difference = 1.94
|
||||
|
||||
expected_memory_saving_ratio = 0.7
|
||||
|
||||
prompt = "a beautiful sunset amidst the mountains."
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
@@ -123,10 +142,8 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(self, "model_fp16"):
|
||||
del self.model_fp16
|
||||
if hasattr(self, "model_8bit"):
|
||||
del self.model_8bit
|
||||
del self.model_fp16
|
||||
del self.model_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
@@ -165,28 +182,6 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
linear = get_some_linear_layer(self.model_8bit)
|
||||
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
|
||||
|
||||
def test_model_memory_usage(self):
|
||||
# Delete to not let anything interfere.
|
||||
del self.model_8bit, self.model_fp16
|
||||
|
||||
# Re-instantiate.
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = {
|
||||
k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
|
||||
}
|
||||
model_fp16 = SD3Transformer2DModel.from_pretrained(
|
||||
self.model_name, subfolder="transformer", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
|
||||
del model_fp16
|
||||
|
||||
config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model_8bit = SD3Transformer2DModel.from_pretrained(
|
||||
self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
|
||||
)
|
||||
quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
|
||||
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
|
||||
|
||||
def test_original_dtype(self):
|
||||
r"""
|
||||
A simple test to check if the model succesfully stores the original dtype
|
||||
@@ -253,7 +248,7 @@ class BnB8bitBasicTests(Base8bitTests):
|
||||
self.assertTrue(linear.weight.dtype == torch.int8)
|
||||
self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
|
||||
|
||||
self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
|
||||
self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
|
||||
self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)
|
||||
|
||||
def test_config_from_pretrained(self):
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.utils import is_optimum_quanto_available, is_torch_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QLinear
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..utils import LoRALayer, get_memory_consumption_stat
|
||||
|
||||
|
||||
@nightly
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_accelerate
|
||||
class QuantoBaseTesterMixin:
|
||||
model_id = None
|
||||
pipeline_model_id = None
|
||||
model_cls = None
|
||||
torch_dtype = torch.bfloat16
|
||||
# the expected reduction in peak memory used compared to an unquantized model expressed as a percentage
|
||||
expected_memory_reduction = 0.0
|
||||
keep_in_fp32_module = ""
|
||||
modules_to_not_convert = ""
|
||||
_test_torch_compile = False
|
||||
|
||||
def setUp(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def tearDown(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def get_dummy_init_kwargs(self):
|
||||
return {"weights_dtype": "float8"}
|
||||
|
||||
def get_dummy_model_init_kwargs(self):
|
||||
return {
|
||||
"pretrained_model_name_or_path": self.model_id,
|
||||
"torch_dtype": self.torch_dtype,
|
||||
"quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
|
||||
}
|
||||
|
||||
def test_quanto_layers(self):
|
||||
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert isinstance(module, QLinear)
|
||||
|
||||
def test_quanto_memory_usage(self):
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = {
|
||||
k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
|
||||
}
|
||||
|
||||
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
|
||||
unquantized_model.to(torch_device)
|
||||
unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
|
||||
|
||||
quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
|
||||
quantized_model.to(torch_device)
|
||||
quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
|
||||
|
||||
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
|
||||
|
||||
def test_keep_modules_in_fp32(self):
|
||||
r"""
|
||||
A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
|
||||
Also ensures if inference works.
|
||||
"""
|
||||
_keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
|
||||
self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
|
||||
|
||||
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
|
||||
model.to("cuda")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if name in model._keep_in_fp32_modules:
|
||||
assert module.weight.dtype == torch.float32
|
||||
self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
init_kwargs = self.get_dummy_model_init_kwargs()
|
||||
|
||||
quantization_config_kwargs = self.get_dummy_init_kwargs()
|
||||
quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
|
||||
quantization_config = QuantoConfig(**quantization_config_kwargs)
|
||||
|
||||
init_kwargs.update({"quantization_config": quantization_config})
|
||||
|
||||
model = self.model_cls.from_pretrained(**init_kwargs)
|
||||
model.to("cuda")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if name in self.modules_to_not_convert:
|
||||
assert not isinstance(module, QLinear)
|
||||
|
||||
def test_dtype_assignment(self):
|
||||
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype`
|
||||
model.to(torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device` and `dtype`
|
||||
model.to(device="cuda:0", dtype=torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a cast
|
||||
model.float()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a cast
|
||||
model.half()
|
||||
|
||||
# This should work
|
||||
model.to("cuda")
|
||||
|
||||
def test_serialization(self):
|
||||
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
|
||||
inputs = self.get_dummy_inputs()
|
||||
|
||||
model.to(torch_device)
|
||||
with torch.no_grad():
|
||||
model_output = model(**inputs)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
saved_model = self.model_cls.from_pretrained(
|
||||
tmp_dir,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
saved_model.to(torch_device)
|
||||
with torch.no_grad():
|
||||
saved_model_output = saved_model(**inputs)
|
||||
|
||||
assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_torch_compile(self):
|
||||
if not self._test_torch_compile:
|
||||
return
|
||||
|
||||
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
|
||||
compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
|
||||
|
||||
model.to(torch_device)
|
||||
with torch.no_grad():
|
||||
model_output = model(**self.get_dummy_inputs()).sample
|
||||
|
||||
compiled_model.to(torch_device)
|
||||
with torch.no_grad():
|
||||
compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
|
||||
|
||||
model_output = model_output.detach().float().cpu().numpy()
|
||||
compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_device_map_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = self.model_cls.from_pretrained(
|
||||
**self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
|
||||
)
|
||||
|
||||
|
||||
class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
|
||||
model_id = "hf-internal-testing/tiny-flux-transformer"
|
||||
model_cls = FluxTransformer2DModel
|
||||
pipeline_cls = FluxPipeline
|
||||
torch_dtype = torch.bfloat16
|
||||
keep_in_fp32_module = "proj_out"
|
||||
modules_to_not_convert = ["proj_out"]
|
||||
_test_torch_compile = False
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
|
||||
torch_device, self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": torch.randn(
|
||||
(1, 512, 4096),
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).to(torch_device, self.torch_dtype),
|
||||
"pooled_projections": torch.randn(
|
||||
(1, 768),
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).to(torch_device, self.torch_dtype),
|
||||
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
||||
"img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
|
||||
torch_device, self.torch_dtype
|
||||
),
|
||||
"txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
|
||||
torch_device, self.torch_dtype
|
||||
),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
def get_dummy_training_inputs(self, device=None, seed: int = 0):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
torch.manual_seed(seed)
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
|
||||
device, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"txt_ids": text_ids,
|
||||
"img_ids": image_ids,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def test_model_cpu_offload(self):
|
||||
init_kwargs = self.get_dummy_init_kwargs()
|
||||
transformer = self.model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
quantization_config=QuantoConfig(**init_kwargs),
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = self.pipeline_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
_ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
|
||||
|
||||
def test_training(self):
|
||||
quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
|
||||
quantized_model = self.model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(torch_device)
|
||||
|
||||
for param in quantized_model.parameters():
|
||||
# freeze the model as only adapter layers will be trained
|
||||
param.requires_grad = False
|
||||
if param.ndim == 1:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
for _, module in quantized_model.named_modules():
|
||||
if isinstance(module, Attention):
|
||||
module.to_q = LoRALayer(module.to_q, rank=4)
|
||||
module.to_k = LoRALayer(module.to_k, rank=4)
|
||||
module.to_v = LoRALayer(module.to_v, rank=4)
|
||||
|
||||
with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
|
||||
inputs = self.get_dummy_training_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output.norm().backward()
|
||||
|
||||
for module in quantized_model.modules():
|
||||
if isinstance(module, LoRALayer):
|
||||
self.assertTrue(module.adapter[1].weight.grad is not None)
|
||||
|
||||
|
||||
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
|
||||
expected_memory_reduction = 0.6
|
||||
|
||||
def get_dummy_init_kwargs(self):
|
||||
return {"weights_dtype": "float8"}
|
||||
|
||||
|
||||
class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
|
||||
expected_memory_reduction = 0.6
|
||||
_test_torch_compile = True
|
||||
|
||||
def get_dummy_init_kwargs(self):
|
||||
return {"weights_dtype": "int8"}
|
||||
|
||||
|
||||
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
|
||||
expected_memory_reduction = 0.55
|
||||
|
||||
def get_dummy_init_kwargs(self):
|
||||
return {"weights_dtype": "int4"}
|
||||
|
||||
|
||||
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
|
||||
expected_memory_reduction = 0.65
|
||||
|
||||
def get_dummy_init_kwargs(self):
|
||||
return {"weights_dtype": "int2"}
|
||||
@@ -50,7 +50,27 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import LoRALayer, get_memory_consumption_stat
|
||||
class LoRALayer(nn.Module):
|
||||
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
|
||||
|
||||
Taken from
|
||||
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module, rank: int):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.adapter = nn.Sequential(
|
||||
nn.Linear(module.in_features, rank, bias=False),
|
||||
nn.Linear(rank, module.out_features, bias=False),
|
||||
)
|
||||
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
|
||||
nn.init.normal_(self.adapter[0].weight, std=small_std)
|
||||
nn.init.zeros_(self.adapter[1].weight)
|
||||
self.adapter.to(module.weight.device)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
@@ -483,22 +503,6 @@ class TorchAoTest(unittest.TestCase):
|
||||
# there is additional overhead of scales and zero points
|
||||
self.assertTrue(total_bf16 < total_int4wo)
|
||||
|
||||
def test_model_memory_usage(self):
|
||||
model_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
expected_memory_saving_ratio = 2.0
|
||||
|
||||
inputs = self.get_dummy_tensor_inputs(device=torch_device)
|
||||
|
||||
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
|
||||
transformer_bf16.to(torch_device)
|
||||
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
|
||||
del transformer_bf16
|
||||
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_int8wo.to(torch_device)
|
||||
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
|
||||
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
|
||||
|
||||
def test_wrong_config(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.get_dummy_components(TorchAoConfig("int42"))
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from diffusers.utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class LoRALayer(nn.Module):
|
||||
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
|
||||
|
||||
Taken from
|
||||
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module, rank: int):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.adapter = nn.Sequential(
|
||||
nn.Linear(module.in_features, rank, bias=False),
|
||||
nn.Linear(rank, module.out_features, bias=False),
|
||||
)
|
||||
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
|
||||
nn.init.normal_(self.adapter[0].weight, std=small_std)
|
||||
nn.init.zeros_(self.adapter[1].weight)
|
||||
self.adapter.to(module.weight.device)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def get_memory_consumption_stat(model, inputs):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model(**inputs)
|
||||
max_memory_mem_allocated = torch.cuda.max_memory_allocated()
|
||||
return max_memory_mem_allocated
|
||||
@@ -1,61 +0,0 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
SanaTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||
model_class = SanaTransformer2DModel
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
|
||||
)
|
||||
alternate_keys_ckpt_paths = [
|
||||
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
|
||||
]
|
||||
|
||||
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
torch.cuda.empty_cache()
|
||||
model = self.model_class.from_single_file(ckpt_path)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
Reference in New Issue
Block a user