Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c73c00610e | |||
| 233dffdc3f | |||
| be2070991f | |||
| bf9a641f1a | |||
| a756694bf0 | |||
| d41388145e | |||
| a6288a5571 | |||
| 7d4db57037 | |||
| 902008608a | |||
| c8ee4af228 | |||
| b64ca6c11c | |||
| e12d610faa | |||
| bf6eaa8aec | |||
| 17128c42a4 | |||
| dbc1d505f0 | |||
| 151b74cd77 | |||
| 41ba8c0bf6 | |||
| 3191248472 | |||
| 648d968cfc | |||
| b756ec6e80 | |||
| d8825e7697 | |||
| 074798b299 | |||
| 3ee966950b | |||
| 9764f229d4 | |||
| 1826a1e7d3 | |||
| 0ed09a17bb | |||
| 2f7a417d1f | |||
| 4450d26b63 | |||
| f781b8c30c | |||
| 9c0e20de61 | |||
| f35a38725b | |||
| f66bd3261c | |||
| c4c99c3907 | |||
| 862a7d5038 | |||
| 8304adce2a | |||
| b389f339ec | |||
| e222246b4e | |||
| 83709d5a06 | |||
| 8eb73c872a | |||
| 88b015dc9f | |||
| 63cdf9c0ba |
@@ -46,7 +46,7 @@ jobs:
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
${CONDA_RUN} python -m uv pip install transformers --upgrade
|
||||
|
||||
@@ -238,6 +238,8 @@
|
||||
title: Textual Inversion
|
||||
- local: api/loaders/unet
|
||||
title: UNet
|
||||
- local: api/loaders/transformer_sd3
|
||||
title: SD3Transformer2D
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
title: Loaders
|
||||
@@ -400,6 +402,8 @@
|
||||
title: DiT
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
|
||||
@@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech
|
||||
|
||||
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
|
||||
|
||||
## JointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
|
||||
|
||||
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
|
||||
## SD3IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
|
||||
- all
|
||||
- is_ip_adapter_active
|
||||
|
||||
## IPAdapterMaskProcessor
|
||||
|
||||
[[autodoc]] image_processor.IPAdapterMaskProcessor
|
||||
@@ -0,0 +1,29 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# SD3Transformer2D
|
||||
|
||||
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
|
||||
|
||||
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
|
||||
|
||||
<Tip>
|
||||
|
||||
To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## SD3Transformer2DLoadersMixin
|
||||
|
||||
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
|
||||
- all
|
||||
- _load_ip_adapter_weights
|
||||
@@ -0,0 +1,89 @@
|
||||
<!--Copyright 2024 The HuggingFace Team, The Black Forest Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# FluxControlInpaint
|
||||
|
||||
FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
|
||||
|
||||
FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
|
||||
|
||||
| Control type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
|
||||
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxControlInpaintPipeline
|
||||
from diffusers.models.transformers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
pipe = FluxControlInpaintPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Depth-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# use following lines if you have GPU constraints
|
||||
# ---------------------------------------------------------------
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.enable_model_cpu_offload()
|
||||
# ---------------------------------------------------------------
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "a blue robot singing opera with human-like expressions"
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
head_mask = np.zeros_like(image)
|
||||
head_mask[65:580,300:642] = 255
|
||||
mask_image = Image.fromarray(head_mask)
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(image)[0].convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=30,
|
||||
strength=0.9,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
|
||||
```
|
||||
|
||||
## FluxControlInpaintPipeline
|
||||
[[autodoc]] FluxControlInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## FluxPipelineOutput
|
||||
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
|
||||
@@ -268,6 +268,43 @@ images = pipe(
|
||||
images[0].save("flux-redux.png")
|
||||
```
|
||||
|
||||
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
|
||||
|
||||
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
|
||||
|
||||
```py
|
||||
from diffusers import FluxControlPipeline
|
||||
from image_gen_aux import DepthPreprocessor
|
||||
from diffusers.utils import load_image
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
|
||||
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
|
||||
control_pipe.load_lora_weights(
|
||||
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
|
||||
)
|
||||
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
|
||||
control_pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(control_image)[0].convert("RGB")
|
||||
|
||||
image = control_pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -61,6 +61,44 @@ pipe = LTXImageToVideoPipeline.from_single_file(
|
||||
)
|
||||
```
|
||||
|
||||
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
|
||||
)
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = LTXPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=704,
|
||||
height=480,
|
||||
num_frames=161,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
```
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# Mochi
|
||||
# Mochi 1 Preview
|
||||
|
||||
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
|
||||
|
||||
@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
## Generating videos with Mochi-1 Preview
|
||||
|
||||
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(prompt, num_frames=85).frames[0]
|
||||
|
||||
export_to_video(frames, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Using a lower precision variant to save memory
|
||||
|
||||
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
frames = pipe(prompt, num_frames=85).frames[0]
|
||||
|
||||
export_to_video(frames, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Reproducing the results from the Genmo Mochi repo
|
||||
|
||||
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
|
||||
|
||||
<Tip>
|
||||
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
|
||||
|
||||
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
|
||||
pipe.enable_vae_tiling()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
|
||||
pipe.encode_prompt(prompt=prompt)
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16):
|
||||
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
||||
frames = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
guidance_scale=4.5,
|
||||
num_inference_steps=64,
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=163,
|
||||
generator=torch.Generator("cuda").manual_seed(0),
|
||||
output_type="latent",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
video_processor = VideoProcessor(vae_scale_factor=8)
|
||||
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
frames = frames / pipe.vae.config.scaling_factor
|
||||
|
||||
with torch.no_grad():
|
||||
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
|
||||
|
||||
video = video_processor.postprocess_video(video)[0]
|
||||
export_to_video(video, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Running inference with multiple GPUs
|
||||
|
||||
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "genmo/mochi-1-preview"
|
||||
transformer = MochiTransformer3DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
device_map="auto",
|
||||
max_memory={0: "24GB", 1: "24GB"}
|
||||
)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(
|
||||
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
|
||||
negative_prompt="",
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=85,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_videos_per_prompt=1,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
max_sequence_length=256,
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Using single file loading with the Mochi Transformer
|
||||
|
||||
You can use `from_single_file` to load the Mochi transformer in its original format.
|
||||
|
||||
<Tip>
|
||||
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "genmo/mochi-1-preview"
|
||||
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
|
||||
|
||||
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(
|
||||
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
|
||||
negative_prompt="",
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=85,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_videos_per_prompt=1,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
max_sequence_length=256,
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## MochiPipeline
|
||||
|
||||
[[autodoc]] MochiPipeline
|
||||
|
||||
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
|
||||
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
|
||||
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
|
||||
|
||||
## Image Prompting with IP-Adapters
|
||||
|
||||
An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
|
||||
|
||||
- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
|
||||
- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
|
||||
- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
|
||||
|
||||
IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
from transformers import SiglipVisionModel, SiglipImageProcessor
|
||||
|
||||
image_encoder_id = "google/siglip-so400m-patch14-384"
|
||||
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
|
||||
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
image_encoder_id,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
image_encoder_id,
|
||||
torch_dtype=torch.float16
|
||||
).to( "cuda")
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
torch_dtype=torch.float16,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
).to("cuda")
|
||||
|
||||
pipe.load_ip_adapter(ip_adapter_id)
|
||||
pipe.set_ip_adapter_scale(0.6)
|
||||
|
||||
ref_img = Image.open("image.jpg").convert('RGB')
|
||||
|
||||
image = pipe(
|
||||
width=1024,
|
||||
height=1024,
|
||||
prompt="a cat",
|
||||
negative_prompt="lowres, low quality, worst quality",
|
||||
num_inference_steps=24,
|
||||
guidance_scale=5.0,
|
||||
ip_adapter_image=ref_img
|
||||
).images[0]
|
||||
|
||||
image.save("result.jpg")
|
||||
```
|
||||
|
||||
<div class="justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/>
|
||||
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption>
|
||||
</div>
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Memory Optimisations for SD3
|
||||
|
||||
SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
|
||||
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
|
||||
|
||||
### Running Inference with Model Offloading
|
||||
|
||||
|
||||
@@ -45,12 +45,11 @@ transformer = FluxTransformer2DModel.from_single_file(
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
generator=torch.manual_seed(0),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be
|
||||
## When to use what?
|
||||
|
||||
Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes.md)
|
||||
- [TorchAO](./torchao.md)
|
||||
- [GGUF](./gguf.md)
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
@@ -73,7 +73,7 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face
|
||||
Now, we can launch training using:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
|
||||
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-sana-lora"
|
||||
|
||||
@@ -124,4 +124,4 @@ We provide several options for optimizing memory optimization:
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
|
||||
@@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import safetensors.torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
vision = True
|
||||
else:
|
||||
vision = False
|
||||
|
||||
"""
|
||||
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
|
||||
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
|
||||
--filename "flux-ip-adapter.safetensors"
|
||||
--output_path "flux-ip-adapter-hf/"
|
||||
"""
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--filename", default="flux.safetensors", type=str)
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--output_path", type=str)
|
||||
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
|
||||
converted_state_dict = {}
|
||||
|
||||
# image_proj
|
||||
## norm
|
||||
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
||||
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
||||
## proj
|
||||
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
||||
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"ip_adapter.{i}."
|
||||
# to_k_ip
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
|
||||
)
|
||||
# to_v_ip
|
||||
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def main(args):
|
||||
original_ckpt = load_original_checkpoint(args)
|
||||
|
||||
num_layers = 19
|
||||
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
|
||||
|
||||
print("Saving Flux IP-Adapter in Diffusers format.")
|
||||
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
|
||||
|
||||
if vision:
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
|
||||
model.save_pretrained(f"{args.output_path}/image_encoder")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
|
||||
@@ -265,9 +266,9 @@ if __name__ == "__main__":
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[512, 1024],
|
||||
choices=[512, 1024, 2048],
|
||||
required=False,
|
||||
help="Image size of pretrained model, 512 or 1024.",
|
||||
help="Image size of pretrained model, 512, 1024 or 2048.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
|
||||
@@ -277,6 +277,7 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
"FluxControlNetInpaintPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
@@ -765,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
|
||||
@@ -55,7 +55,8 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
|
||||
|
||||
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
|
||||
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
if is_transformers_available():
|
||||
@@ -70,10 +71,15 @@ if is_torch_available():
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
"IPAdapterMixin",
|
||||
"FluxIPAdapterMixin",
|
||||
"SD3IPAdapterMixin",
|
||||
]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
@@ -81,15 +87,22 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file_model import FromOriginalModelMixin
|
||||
from .transformer_flux import FluxTransformer2DLoadersMixin
|
||||
from .transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
if is_transformers_available():
|
||||
from .ip_adapter import IPAdapterMixin
|
||||
from .ip_adapter import (
|
||||
FluxIPAdapterMixin,
|
||||
IPAdapterMixin,
|
||||
SD3IPAdapterMixin,
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
|
||||
@@ -33,15 +33,20 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
JointAttnProcessor2_0,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -348,3 +353,519 @@ class IPAdapterMixin:
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||
for key in f.keys():
|
||||
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||
diffusers_name = ".".join(key.split(".")[1:])
|
||||
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||
diffusers_name = (
|
||||
".".join(key.split(".")[1:])
|
||||
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||
.replace("processor.", "")
|
||||
)
|
||||
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_pretrained_model_name_or_path is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||
image_encoder = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_pretrained_model_name_or_path,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
.to(self.device, dtype=image_encoder_dtype)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
|
||||
def LinearStrengthModel(start, finish, size):
|
||||
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||
|
||||
|
||||
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
transformer = self.transformer
|
||||
if not isinstance(scale, list):
|
||||
scale = [[scale] * transformer.config.num_layers]
|
||||
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
|
||||
if len(scale) != transformer.config.num_layers:
|
||||
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
|
||||
scale = [scale]
|
||||
|
||||
scale_configs = scale
|
||||
|
||||
key_id = 0
|
||||
for attn_name, attn_processor in transformer.attn_processors.items():
|
||||
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
attn_processor.scale[i] = scale_config[key_id]
|
||||
key_id += 1
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.transformer.encoder_hid_proj = None
|
||||
self.transformer.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Transformer attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.transformer.attn_processors.items():
|
||||
attn_processor_class = FluxAttnProcessor2_0()
|
||||
attn_procs[name] = (
|
||||
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||
)
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class SD3IPAdapterMixin:
|
||||
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||
|
||||
@property
|
||||
def is_ip_adapter_active(self) -> bool:
|
||||
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||
|
||||
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||
the image context is irrelevant.
|
||||
|
||||
Returns:
|
||||
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||
"""
|
||||
scales = [
|
||||
attn_proc.scale
|
||||
for attn_proc in self.transformer.attn_processors.values()
|
||||
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||
]
|
||||
|
||||
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# Load the main state dict first
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
# Commons args for loading image encoder and image processor
|
||||
kwargs = {
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# Load IP-Adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||
"""
|
||||
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
|
||||
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
scale (float):
|
||||
IP-Adapter scale to be set.
|
||||
|
||||
"""
|
||||
for attn_processor in self.transformer.attn_processors.values():
|
||||
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self) -> None:
|
||||
"""
|
||||
Unloads the IP Adapter weights.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# Remove image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=None)
|
||||
|
||||
# Remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=None)
|
||||
|
||||
# Remove image projection
|
||||
self.transformer.image_proj = None
|
||||
|
||||
# Restore original attention processors layers
|
||||
attn_procs = {
|
||||
name: (
|
||||
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||
)
|
||||
for name, value in self.transformer.attn_processors.items()
|
||||
}
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
@@ -643,7 +643,11 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
|
||||
[
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
|
||||
],
|
||||
)
|
||||
|
||||
if "down" in old_key:
|
||||
|
||||
@@ -1863,6 +1863,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
|
||||
"To get a comprehensive list of parameter names that were modified, enable debug logging."
|
||||
)
|
||||
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
|
||||
transformer=transformer, lora_state_dict=transformer_lora_state_dict
|
||||
)
|
||||
|
||||
if len(transformer_lora_state_dict) > 0:
|
||||
self.load_lora_into_transformer(
|
||||
@@ -2309,16 +2312,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Expand transformer parameter shapes if they don't match lora
|
||||
has_param_with_shape_update = False
|
||||
|
||||
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
|
||||
for name, module in transformer.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
module_weight = module.weight.data
|
||||
module_bias = module.bias.data if module.bias is not None else None
|
||||
bias = module_bias is not None
|
||||
|
||||
lora_A_weight_name = f"{name}.lora_A.weight"
|
||||
lora_B_weight_name = f"{name}.lora_B.weight"
|
||||
if lora_A_weight_name not in state_dict.keys():
|
||||
lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
|
||||
lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
|
||||
lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
|
||||
if lora_A_weight_name not in state_dict:
|
||||
continue
|
||||
|
||||
in_features = state_dict[lora_A_weight_name].shape[1]
|
||||
@@ -2329,57 +2333,106 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
continue
|
||||
|
||||
module_out_features, module_in_features = module_weight.shape
|
||||
if out_features < module_out_features or in_features < module_in_features:
|
||||
raise NotImplementedError(
|
||||
f"Only LoRAs with input/output features higher than the current module's input/output features "
|
||||
f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
|
||||
f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
|
||||
f"this please open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
debug_message = ""
|
||||
if in_features > module_in_features:
|
||||
debug_message += (
|
||||
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
|
||||
f"checkpoint contains higher number of features than expected. The number of input_features will be "
|
||||
f"expanded from {module_in_features} to {in_features}"
|
||||
)
|
||||
|
||||
debug_message = (
|
||||
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
|
||||
f"checkpoint contains higher number of features than expected. The number of input_features will be "
|
||||
f"expanded from {module_in_features} to {in_features}"
|
||||
)
|
||||
if module_out_features != out_features:
|
||||
if out_features > module_out_features:
|
||||
debug_message += (
|
||||
", and the number of output features will be "
|
||||
f"expanded from {module_out_features} to {out_features}."
|
||||
)
|
||||
else:
|
||||
debug_message += "."
|
||||
logger.debug(debug_message)
|
||||
if debug_message:
|
||||
logger.debug(debug_message)
|
||||
|
||||
has_param_with_shape_update = True
|
||||
parent_module_name, _, current_module_name = name.rpartition(".")
|
||||
parent_module = transformer.get_submodule(parent_module_name)
|
||||
if out_features > module_out_features or in_features > module_in_features:
|
||||
has_param_with_shape_update = True
|
||||
parent_module_name, _, current_module_name = name.rpartition(".")
|
||||
parent_module = transformer.get_submodule(parent_module_name)
|
||||
|
||||
# TODO: consider initializing this under meta device for optims.
|
||||
expanded_module = torch.nn.Linear(
|
||||
in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
|
||||
)
|
||||
# Only weights are expanded and biases are not.
|
||||
new_weight = torch.zeros_like(
|
||||
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
|
||||
)
|
||||
slices = tuple(slice(0, dim) for dim in module_weight.shape)
|
||||
new_weight[slices] = module_weight
|
||||
expanded_module.weight.data.copy_(new_weight)
|
||||
if module_bias is not None:
|
||||
expanded_module.bias.data.copy_(module_bias)
|
||||
with torch.device("meta"):
|
||||
expanded_module = torch.nn.Linear(
|
||||
in_features, out_features, bias=bias, dtype=module_weight.dtype
|
||||
)
|
||||
# Only weights are expanded and biases are not. This is because only the input dimensions
|
||||
# are changed while the output dimensions remain the same. The shape of the weight tensor
|
||||
# is (out_features, in_features), while the shape of bias tensor is (out_features,), which
|
||||
# explains the reason why only weights are expanded.
|
||||
new_weight = torch.zeros_like(
|
||||
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
|
||||
)
|
||||
slices = tuple(slice(0, dim) for dim in module_weight.shape)
|
||||
new_weight[slices] = module_weight
|
||||
tmp_state_dict = {"weight": new_weight}
|
||||
if module_bias is not None:
|
||||
tmp_state_dict["bias"] = module_bias
|
||||
expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
|
||||
|
||||
setattr(parent_module, current_module_name, expanded_module)
|
||||
setattr(parent_module, current_module_name, expanded_module)
|
||||
|
||||
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
|
||||
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
|
||||
new_value = int(expanded_module.weight.data.shape[1])
|
||||
old_value = getattr(transformer.config, attribute_name)
|
||||
setattr(transformer.config, attribute_name, new_value)
|
||||
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
|
||||
del tmp_state_dict
|
||||
|
||||
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
|
||||
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
|
||||
new_value = int(expanded_module.weight.data.shape[1])
|
||||
old_value = getattr(transformer.config, attribute_name)
|
||||
setattr(transformer.config, attribute_name, new_value)
|
||||
logger.info(
|
||||
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
|
||||
)
|
||||
|
||||
return has_param_with_shape_update
|
||||
|
||||
@classmethod
|
||||
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
|
||||
expanded_module_names = set()
|
||||
transformer_state_dict = transformer.state_dict()
|
||||
prefix = f"{cls.transformer_name}."
|
||||
|
||||
lora_module_names = [
|
||||
key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
|
||||
]
|
||||
lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
|
||||
lora_module_names = sorted(set(lora_module_names))
|
||||
transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
|
||||
unexpected_modules = set(lora_module_names) - set(transformer_module_names)
|
||||
if unexpected_modules:
|
||||
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
|
||||
|
||||
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
|
||||
for k in lora_module_names:
|
||||
if k in unexpected_modules:
|
||||
continue
|
||||
|
||||
base_param_name = (
|
||||
f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
|
||||
)
|
||||
base_weight_param = transformer_state_dict[base_param_name]
|
||||
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
|
||||
|
||||
if base_weight_param.shape[1] > lora_A_param.shape[1]:
|
||||
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
|
||||
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
|
||||
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
|
||||
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
|
||||
expanded_module_names.add(k)
|
||||
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
|
||||
raise NotImplementedError(
|
||||
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
|
||||
if expanded_module_names:
|
||||
logger.info(
|
||||
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
|
||||
)
|
||||
|
||||
return lora_state_dict
|
||||
|
||||
|
||||
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
|
||||
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
|
||||
@@ -3870,6 +3923,314 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`HunyuanVideoTransformer3DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -53,6 +53,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ from .single_file_utils import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
@@ -96,6 +97,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -99,13 +99,15 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
],
|
||||
"ltx-video": [
|
||||
(
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
),
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
"patchify_proj.weight",
|
||||
"transformer_blocks.27.scale_shift_table",
|
||||
"vae.per_channel_statistics.mean-of-means",
|
||||
],
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -151,12 +153,15 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
|
||||
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
|
||||
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
|
||||
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -587,11 +592,17 @@ def infer_diffusers_model_type(checkpoint):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
model_type = "flux-dev"
|
||||
if checkpoint["img_in.weight"].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
|
||||
elif checkpoint["img_in.weight"].shape[1] == 128:
|
||||
model_type = "flux-depth"
|
||||
else:
|
||||
model_type = "flux-dev"
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
@@ -610,6 +621,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "autoencoder-dc-f128c512"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
|
||||
model_type = "mochi-1-preview"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1750,6 +1764,12 @@ def swap_scale_shift(weight, dim):
|
||||
return new_weight
|
||||
|
||||
|
||||
def swap_proj_gate(weight):
|
||||
proj, gate = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([gate, proj], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def get_attn2_layers(state_dict):
|
||||
attn2_layers = []
|
||||
for key in state_dict.keys():
|
||||
@@ -2247,9 +2267,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
|
||||
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {
|
||||
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
|
||||
}
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"model.diffusion_model.": "",
|
||||
@@ -2406,3 +2424,101 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict = {}
|
||||
|
||||
# Comfy checkpoints add this prefix
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
# Convert patch_embed
|
||||
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Convert time_embed
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
|
||||
# Convert transformer blocks
|
||||
num_layers = 48
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
old_prefix = f"blocks.{i}."
|
||||
|
||||
# norm1
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
|
||||
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
else:
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
|
||||
# Visual attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
|
||||
# Context attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_y.weight"
|
||||
)
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
|
||||
|
||||
# MLP
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
|
||||
|
||||
# Output layers
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
|
||||
from ..models.attention_processor import (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict
|
||||
|
||||
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
|
||||
|
||||
class SD3Transformer2DLoadersMixin:
|
||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict`):
|
||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
||||
"image_proj", which contains parameters for image projection net.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# IP-Adapter cross attention parameters
|
||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
|
||||
|
||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
||||
for key, weights in state_dict["ip_adapter"].items():
|
||||
idx, name = key.split(".", maxsplit=1)
|
||||
layer_state_dict[int(idx)][name] = weights
|
||||
|
||||
# Create IP-Adapter attention processor
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(self.attn_processors.keys()):
|
||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
||||
head_dim=self.config.attention_head_dim,
|
||||
timesteps_emb_dim=timesteps_emb_dim,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
# Image projetion parameters
|
||||
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
|
||||
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
|
||||
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
|
||||
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
num_queries = state_dict["image_proj"]["latents"].shape[1]
|
||||
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
# Image projection
|
||||
self.image_proj = IPAdapterTimeImageProjection(
|
||||
embed_dim=embed_dim,
|
||||
output_dim=output_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
timestep_in_dim=timestep_in_dim,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
|
||||
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
hidden_states, emb=temb
|
||||
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
if self.use_dual_attention:
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
|
||||
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
|
||||
hidden_states = hidden_states + attn_output2
|
||||
|
||||
|
||||
@@ -575,7 +575,7 @@ class Attention(nn.Module):
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks"}
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
||||
unused_kwargs = [
|
||||
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
|
||||
]
|
||||
@@ -2653,6 +2653,149 @@ class FusedFluxAttnProcessor2_0_NPU:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
ip_hidden_states: Optional[List[torch.Tensor]] = None,
|
||||
ip_adapter_masks: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
hidden_states_query_proj = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
if encoder_hidden_states is not None:
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
# IP-adapter
|
||||
ip_query = hidden_states_query_proj
|
||||
ip_attn_output = None
|
||||
# for ip-adapter
|
||||
# TODO: support for multiple adapters
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_attn_output = F.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_attn_output = scale * ip_attn_output
|
||||
ip_attn_output = ip_attn_output.to(ip_query.dtype)
|
||||
|
||||
return hidden_states, encoder_hidden_states, ip_attn_output
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
@@ -5243,6 +5386,177 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
"""
|
||||
Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
|
||||
additional image-based information and timestep embeddings.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The number of hidden channels.
|
||||
ip_hidden_states_dim (`int`):
|
||||
The image feature dimension.
|
||||
head_dim (`int`):
|
||||
The number of head channels.
|
||||
timesteps_emb_dim (`int`, defaults to 1280):
|
||||
The number of input channels for timestep embedding.
|
||||
scale (`float`, defaults to 0.5):
|
||||
IP-Adapter scale.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ip_hidden_states_dim: int,
|
||||
head_dim: int,
|
||||
timesteps_emb_dim: int = 1280,
|
||||
scale: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import
|
||||
from .normalization import AdaLayerNorm, RMSNorm
|
||||
|
||||
self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
|
||||
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
|
||||
self.norm_q = RMSNorm(head_dim, 1e-6)
|
||||
self.norm_k = RMSNorm(head_dim, 1e-6)
|
||||
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
|
||||
self.scale = scale
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
ip_hidden_states: torch.FloatTensor = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Perform the attention computation, integrating image features (if provided) and timestep embeddings.
|
||||
|
||||
If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
|
||||
|
||||
Args:
|
||||
attn (`Attention`):
|
||||
Attention instance.
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
The encoder hidden states.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Attention mask.
|
||||
ip_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Image embeddings.
|
||||
temb (`torch.FloatTensor`, *optional*):
|
||||
Timestep embeddings.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Output hidden states.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
img_query = query
|
||||
img_key = key
|
||||
img_value = value
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
# Split the attention outputs.
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : residual.shape[1]],
|
||||
hidden_states[:, residual.shape[1] :],
|
||||
)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
# IP Adapter
|
||||
if self.scale != 0 and ip_hidden_states is not None:
|
||||
# Norm image features
|
||||
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
|
||||
|
||||
# To k and v
|
||||
ip_key = self.to_k_ip(norm_ip_hidden_states)
|
||||
ip_value = self.to_v_ip(norm_ip_hidden_states)
|
||||
|
||||
# Reshape
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Norm
|
||||
query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
ip_key = self.norm_ip_k(ip_key)
|
||||
|
||||
# cat img
|
||||
key = torch.cat([img_key, ip_key], dim=2)
|
||||
value = torch.cat([img_value, ip_value], dim=2)
|
||||
|
||||
ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + ip_hidden_states * self.scale
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
@@ -5725,6 +6039,7 @@ CROSS_ATTENTION_PROCESSORS = (
|
||||
SlicedAttnProcessor,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
AttentionProcessor = Union[
|
||||
@@ -5772,6 +6087,7 @@ AttentionProcessor = Union[
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
|
||||
@@ -168,6 +168,7 @@ class HunyuanVideoResnetBlockCausal3D(nn.Module):
|
||||
self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
@@ -792,12 +793,12 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 256
|
||||
self.tile_sample_min_width = 256
|
||||
self.tile_sample_min_num_frames = 64
|
||||
self.tile_sample_min_num_frames = 16
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 192
|
||||
self.tile_sample_stride_width = 192
|
||||
self.tile_sample_stride_num_frames = 48
|
||||
self.tile_sample_stride_num_frames = 12
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
|
||||
@@ -1003,7 +1004,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
for i in range(0, height, self.tile_sample_stride_height):
|
||||
row = []
|
||||
for j in range(0, width, self.tile_sample_stride_width):
|
||||
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
||||
tile = self.encoder(tile)
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
@@ -1020,7 +1021,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
result_rows.append(torch.cat(result_row, dim=4))
|
||||
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
@@ -691,7 +691,7 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
output_type="pt",
|
||||
)
|
||||
pos_embedding = pos_embedding.flatten(0, 1)
|
||||
joint_pos_embedding = torch.zeros(
|
||||
joint_pos_embedding = pos_embedding.new_zeros(
|
||||
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
|
||||
)
|
||||
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
|
||||
@@ -1535,7 +1535,7 @@ class ImageProjection(nn.Module):
|
||||
batch_size = image_embeds.shape[0]
|
||||
|
||||
# image
|
||||
image_embeds = self.image_embeds(image_embeds)
|
||||
image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
|
||||
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
||||
image_embeds = self.norm(image_embeds)
|
||||
return image_embeds
|
||||
@@ -2396,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class IPAdapterTimeImageProjectionBlock(nn.Module):
|
||||
"""Block for IPAdapterTimeImageProjection.
|
||||
|
||||
Args:
|
||||
hidden_dim (`int`, defaults to 1280):
|
||||
The number of hidden channels.
|
||||
dim_head (`int`, defaults to 64):
|
||||
The number of head channels.
|
||||
heads (`int`, defaults to 20):
|
||||
Parallel attention heads.
|
||||
ffn_ratio (`int`, defaults to 4):
|
||||
The expansion ratio of feedforward network hidden layer channels.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int = 1280,
|
||||
dim_head: int = 64,
|
||||
heads: int = 20,
|
||||
ffn_ratio: int = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
|
||||
self.ln0 = nn.LayerNorm(hidden_dim)
|
||||
self.ln1 = nn.LayerNorm(hidden_dim)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_dim,
|
||||
cross_attention_dim=hidden_dim,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
)
|
||||
self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
|
||||
|
||||
# AdaLayerNorm
|
||||
self.adaln_silu = nn.SiLU()
|
||||
self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
|
||||
self.adaln_norm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
# Set attention scale and fuse KV
|
||||
self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
|
||||
self.attn.fuse_projections()
|
||||
self.attn.to_k = None
|
||||
self.attn.to_v = None
|
||||
|
||||
def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Image features.
|
||||
latents (`torch.Tensor`):
|
||||
Latent features.
|
||||
timestep_emb (`torch.Tensor`):
|
||||
Timestep embedding.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Output latent features.
|
||||
"""
|
||||
|
||||
# Shift and scale for AdaLayerNorm
|
||||
emb = self.adaln_proj(self.adaln_silu(timestep_emb))
|
||||
shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
|
||||
|
||||
# Fused Attention
|
||||
residual = latents
|
||||
x = self.ln0(x)
|
||||
latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
query = self.attn.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // self.attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
latents = weight @ value
|
||||
|
||||
latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
|
||||
latents = self.attn.to_out[0](latents)
|
||||
latents = self.attn.to_out[1](latents)
|
||||
latents = latents + residual
|
||||
|
||||
## FeedForward
|
||||
residual = latents
|
||||
latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
return self.ff(latents) + residual
|
||||
|
||||
|
||||
# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
class IPAdapterTimeImageProjection(nn.Module):
|
||||
"""Resampler of SD3 IP-Adapter with timestep embedding.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`, defaults to 1152):
|
||||
The feature dimension.
|
||||
output_dim (`int`, defaults to 2432):
|
||||
The number of output channels.
|
||||
hidden_dim (`int`, defaults to 1280):
|
||||
The number of hidden channels.
|
||||
depth (`int`, defaults to 4):
|
||||
The number of blocks.
|
||||
dim_head (`int`, defaults to 64):
|
||||
The number of head channels.
|
||||
heads (`int`, defaults to 20):
|
||||
Parallel attention heads.
|
||||
num_queries (`int`, defaults to 64):
|
||||
The number of queries.
|
||||
ffn_ratio (`int`, defaults to 4):
|
||||
The expansion ratio of feedforward network hidden layer channels.
|
||||
timestep_in_dim (`int`, defaults to 320):
|
||||
The number of input channels for timestep embedding.
|
||||
timestep_flip_sin_to_cos (`bool`, defaults to True):
|
||||
Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
|
||||
timestep_freq_shift (`int`, defaults to 0):
|
||||
Controls the timestep delta between frequencies between dimensions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int = 1152,
|
||||
output_dim: int = 2432,
|
||||
hidden_dim: int = 1280,
|
||||
depth: int = 4,
|
||||
dim_head: int = 64,
|
||||
heads: int = 20,
|
||||
num_queries: int = 64,
|
||||
ffn_ratio: int = 4,
|
||||
timestep_in_dim: int = 320,
|
||||
timestep_flip_sin_to_cos: bool = True,
|
||||
timestep_freq_shift: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
|
||||
self.proj_in = nn.Linear(embed_dim, hidden_dim)
|
||||
self.proj_out = nn.Linear(hidden_dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
self.layers = nn.ModuleList(
|
||||
[IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
||||
)
|
||||
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Image features.
|
||||
timestep (`torch.Tensor`):
|
||||
Timestep in denoising process.
|
||||
Returns:
|
||||
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
|
||||
"""
|
||||
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
|
||||
timestep_emb = self.time_embedding(timestep_emb)
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
x = x + timestep_emb[:, None]
|
||||
|
||||
for block in self.layers:
|
||||
latents = block(x, latents, timestep_emb)
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
latents = self.norm_out(latents)
|
||||
|
||||
return latents, timestep_emb
|
||||
|
||||
|
||||
class MultiIPAdapterImageProjection(nn.Module):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
||||
super().__init__()
|
||||
|
||||
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||
|
||||
|
||||
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
try:
|
||||
return next(parameter.parameters()).dtype
|
||||
except StopIteration:
|
||||
try:
|
||||
return next(parameter.buffers()).dtype
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
"""
|
||||
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
||||
"""
|
||||
last_dtype = None
|
||||
for param in parameter.parameters():
|
||||
last_dtype = param.dtype
|
||||
if param.is_floating_point():
|
||||
return param.dtype
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
for buffer in parameter.buffers():
|
||||
last_dtype = buffer.dtype
|
||||
if buffer.is_floating_point():
|
||||
return buffer.dtype
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
if last_dtype is not None:
|
||||
# if no floating dtype was found return whatever the first dtype is
|
||||
return last_dtype
|
||||
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
last_tuple = None
|
||||
for tuple in gen:
|
||||
last_tuple = tuple
|
||||
if tuple[1].is_floating_point():
|
||||
return tuple[1].dtype
|
||||
|
||||
if last_tuple is not None:
|
||||
# fallback to the last dtype
|
||||
return last_tuple[1].dtype
|
||||
|
||||
|
||||
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
@@ -802,7 +820,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
)
|
||||
if hf_quantizer is not None:
|
||||
if hf_quantizer is not None and is_bnb_quantization_method:
|
||||
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
|
||||
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
|
||||
is_sharded = False
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
@@ -177,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
@@ -195,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
|
||||
@@ -212,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
|
||||
@@ -482,6 +491,11 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
||||
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..embeddings import (
|
||||
@@ -32,6 +33,9 @@ from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideoAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -496,7 +500,47 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of dual-stream blocks to use.
|
||||
num_single_layers (`int`, defaults to `40`):
|
||||
The number of layers of single-stream blocks to use.
|
||||
num_refiner_layers (`int`, defaults to `2`):
|
||||
The number of layers of refiner blocks to use.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
The ratio of the hidden layer size to the input size in the feedforward network.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the spatial patches to use in the patch embedding layer.
|
||||
patch_size_t (`int`, defaults to `1`):
|
||||
The size of the tmeporal patches to use in the patch embedding layer.
|
||||
qk_norm (`str`, defaults to `rms_norm`):
|
||||
The normalization to use for the query and key projections in the attention layers.
|
||||
guidance_embeds (`bool`, defaults to `True`):
|
||||
Whether to use guidance embeddings in the model.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The dimension of the pooled projection of the text embeddings.
|
||||
rope_theta (`float`, defaults to `256.0`):
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -630,8 +674,24 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
pooled_projections: torch.Tensor,
|
||||
guidance: torch.Tensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -717,6 +777,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
@@ -304,7 +305,7 @@ class MochiRoPE(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
@@ -334,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MochiTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
from ...models.attention import FeedForward, JointTransformerBlock
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
@@ -103,7 +103,9 @@ class SD3SingleTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class SD3Transformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Stable Diffusion 3.
|
||||
|
||||
@@ -349,8 +351,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
|
||||
Embeddings projected from the embeddings of input conditions.
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states (`list` of `torch.Tensor`):
|
||||
@@ -390,6 +392,12 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
|
||||
|
||||
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
# Skip specified layers
|
||||
is_skip = True if skip_layers is not None and index_block in skip_layers else False
|
||||
@@ -411,11 +419,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
joint_attention_kwargs,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
elif not is_skip:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
|
||||
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states = output_states + (hidden_states,)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states = output_states + (hidden_states,)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
||||
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
|
||||
@@ -128,6 +128,7 @@ else:
|
||||
]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -539,6 +540,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
|
||||
@@ -35,9 +35,12 @@ from .controlnet import (
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
FluxControlPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxPipeline,
|
||||
@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
|
||||
("auraflow", AuraFlowPipeline),
|
||||
("flux", FluxPipeline),
|
||||
("flux-control", FluxControlPipeline),
|
||||
("flux-controlnet", FluxControlNetPipeline),
|
||||
("lumina", LuminaText2ImgPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
("flux", FluxImg2ImgPipeline),
|
||||
("flux-controlnet", FluxControlNetImg2ImgPipeline),
|
||||
("flux-control", FluxControlImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
("flux-controlnet", FluxControlNetInpaintPipeline),
|
||||
("flux-control", FluxControlInpaintPipeline),
|
||||
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
|
||||
]
|
||||
)
|
||||
@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
if "ControlPipeline" in orig_class_name:
|
||||
to_replace = "ControlPipeline"
|
||||
else:
|
||||
to_replace = "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
|
||||
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
|
||||
else:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
|
||||
if "enable_pag" in kwargs:
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
|
||||
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
|
||||
|
||||
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
|
||||
# the `orig_class_name` can be:
|
||||
# `- *Pipeline` (for regular text-to-image checkpoint)
|
||||
# - `*ControlPipeline` (for Flux tools specific checkpoint)
|
||||
# `- *Img2ImgPipeline` (for refiner checkpoint)
|
||||
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
|
||||
if "Img2Img" in orig_class_name:
|
||||
to_replace = "Img2ImgPipeline"
|
||||
elif "ControlPipeline" in orig_class_name:
|
||||
to_replace = "ControlPipeline"
|
||||
else:
|
||||
to_replace = "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
if enable_pag:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
|
||||
|
||||
if to_replace == "ControlPipeline":
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
|
||||
|
||||
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
|
||||
# The `orig_class_name`` can be:
|
||||
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
|
||||
# - `*ControlPipeline` (for Flux tools specific checkpoint)
|
||||
# - or *Pipeline (for regular text-to-image checkpoint)
|
||||
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
|
||||
if "Inpaint" in orig_class_name:
|
||||
to_replace = "InpaintPipeline"
|
||||
elif "ControlPipeline" in orig_class_name:
|
||||
to_replace = "ControlPipeline"
|
||||
else:
|
||||
to_replace = "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
|
||||
if to_replace == "ControlPipeline":
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
|
||||
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
|
||||
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_flux"] = ["FluxPipeline"]
|
||||
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
|
||||
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
|
||||
_import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
|
||||
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_flux import FluxPipeline
|
||||
from .pipeline_flux_control import FluxControlPipeline
|
||||
from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
|
||||
from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
|
||||
from .pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
|
||||
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
|
||||
|
||||
@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -142,6 +149,7 @@ class FluxPipeline(
|
||||
FluxLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
FluxIPAdapterMixin,
|
||||
):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
@@ -169,8 +177,8 @@ class FluxPipeline(
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -182,6 +190,8 @@ class FluxPipeline(
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -193,6 +203,8 @@ class FluxPipeline(
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
@@ -377,14 +389,60 @@ class FluxPipeline(
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
return image_embeds
|
||||
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
image_embeds = []
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for i, single_image_embeds in enumerate(image_embeds):
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
@@ -419,10 +477,33 @@ class FluxPipeline(
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
@@ -551,6 +632,9 @@ class FluxPipeline(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
true_cfg_scale: float = 1.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
@@ -561,6 +645,12 @@ class FluxPipeline(
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -610,6 +700,17 @@ class FluxPipeline(
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
negative_ip_adapter_image:
|
||||
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -647,8 +748,12 @@ class FluxPipeline(
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
@@ -670,6 +775,7 @@ class FluxPipeline(
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
@@ -684,6 +790,21 @@ class FluxPipeline(
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
if do_true_cfg:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
_,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_2=negative_prompt_2,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
@@ -725,12 +846,43 @@ class FluxPipeline(
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
image_embeds = None
|
||||
negative_image_embeds = None
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
||||
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
negative_ip_adapter_image,
|
||||
negative_ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
@@ -746,6 +898,22 @@ class FluxPipeline(
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -403,7 +403,6 @@ class FluxControlPipeline(
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1095,7 +1095,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
# predict the noise residual
|
||||
if self.controlnet.config.guidance_embeds:
|
||||
if isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
use_guidance = self.controlnet.nets[0].config.guidance_embeds
|
||||
else:
|
||||
use_guidance = self.controlnet.config.guidance_embeds
|
||||
if use_guidance:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import HunyuanVideoLoraLoaderMixin
|
||||
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
@@ -132,7 +133,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using HunyuanVideo.
|
||||
|
||||
@@ -142,7 +143,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
text_encoder ([`LlamaModel`]):
|
||||
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
tokenizer_2 (`LlamaTokenizer`):
|
||||
tokenizer (`LlamaTokenizer`):
|
||||
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
||||
transformer ([`HunyuanVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
@@ -447,6 +448,10 @@ class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -471,6 +476,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
@@ -525,6 +531,10 @@ class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
@@ -562,6 +572,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
@@ -640,6 +651,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: MochiTransformer3DModel,
|
||||
force_zeros_for_empty_prompt: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -205,10 +206,11 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
|
||||
)
|
||||
self.default_height = 480
|
||||
self.default_width = 848
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
@@ -236,7 +238,11 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
if prompt == "" or prompt[-1] == "":
|
||||
|
||||
# The original Mochi implementation zeros out empty negative prompts
|
||||
# but this can lead to overflow when placing the entire pipeline under the autocast context
|
||||
# adding this here so that we can enable zeroing prompts if necessary
|
||||
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
|
||||
text_input_ids = torch.zeros_like(text_input_ids, device=device)
|
||||
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -811,6 +812,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
# extract them here
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
expected_types = pipeline_class._get_signature_types()
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
@@ -833,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for key in init_dict.keys():
|
||||
if key not in passed_class_obj:
|
||||
continue
|
||||
if "scheduler" in key:
|
||||
continue
|
||||
|
||||
class_obj = passed_class_obj[key]
|
||||
_expected_class_types = []
|
||||
for expected_type in expected_types[key]:
|
||||
if isinstance(expected_type, enum.EnumMeta):
|
||||
_expected_class_types.extend(expected_type.__members__.keys())
|
||||
else:
|
||||
_expected_class_types.append(expected_type.__name__)
|
||||
|
||||
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
|
||||
if not _is_valid_type:
|
||||
logger.warning(
|
||||
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
|
||||
)
|
||||
|
||||
# Special case: safety_checker must be loaded separately when using `from_flax`
|
||||
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -446,7 +446,7 @@ class StableAudioPipeline(DiffusionPipeline):
|
||||
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
|
||||
)
|
||||
|
||||
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
|
||||
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
|
||||
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
|
||||
|
||||
# check num_channels
|
||||
|
||||
@@ -255,7 +255,12 @@ class StableDiffusionPipeline(
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
|
||||
is_unet_sample_size_less_64 = (
|
||||
hasattr(unet.config, "sample_size")
|
||||
and self._is_unet_config_sample_size_int
|
||||
and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -902,8 +907,18 @@ class StableDiffusionPipeline(
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
if not height or not width:
|
||||
height = (
|
||||
self.unet.config.sample_size
|
||||
if self._is_unet_config_sample_size_int
|
||||
else self.unet.config.sample_size[0]
|
||||
)
|
||||
width = (
|
||||
self.unet.config.sample_size
|
||||
if self._is_unet_config_sample_size_int
|
||||
else self.unet.config.sample_size[1]
|
||||
)
|
||||
height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
|
||||
# to deal with lora scaling and other possible forward hooks
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -17,14 +17,16 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
PreTrainedModel,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -142,7 +144,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
|
||||
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
@@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
image_encoder (`PreTrainedModel`, *optional*):
|
||||
Pre-trained Vision Model for IP Adapter.
|
||||
feature_extractor (`BaseImageProcessor`, *optional*):
|
||||
Image processor for IP Adapter.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -191,6 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
image_encoder: PreTrainedModel = None,
|
||||
feature_extractor: BaseImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -204,6 +212,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
tokenizer_3=tokenizer_3,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
@@ -683,6 +693,83 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
|
||||
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
|
||||
"""Encodes the given image into a feature representation using a pre-trained image encoder.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
Input image to be encoded.
|
||||
device: (`torch.device`):
|
||||
Torch device.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The encoded image feature representation.
|
||||
"""
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=self.dtype)
|
||||
|
||||
return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
|
||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Prepares image embeddings for use in the IP-Adapter.
|
||||
|
||||
Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
|
||||
|
||||
Args:
|
||||
ip_adapter_image (`PipelineImageInput`, *optional*):
|
||||
The input image to extract features from for IP-Adapter.
|
||||
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
||||
Precomputed image embeddings.
|
||||
device: (`torch.device`, *optional*):
|
||||
Torch device.
|
||||
num_images_per_prompt (`int`, defaults to 1):
|
||||
Number of images that should be generated per prompt.
|
||||
do_classifier_free_guidance (`bool`, defaults to True):
|
||||
Whether to use classifier free guidance or not.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
if ip_adapter_image_embeds is not None:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
|
||||
else:
|
||||
single_image_embeds = ip_adapter_image_embeds
|
||||
elif ip_adapter_image is not None:
|
||||
single_image_embeds = self.encode_image(ip_adapter_image, device)
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds = torch.zeros_like(single_image_embeds)
|
||||
else:
|
||||
raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
|
||||
|
||||
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
|
||||
|
||||
return image_embeds.to(device=device)
|
||||
|
||||
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
||||
if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
|
||||
logger.warning(
|
||||
"`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
|
||||
"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
|
||||
"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
|
||||
)
|
||||
|
||||
super().enable_sequential_cpu_offload(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -705,6 +792,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -713,9 +802,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
skip_guidance_layers: List[int] = None,
|
||||
skip_layer_guidance_scale: int = 2.8,
|
||||
skip_layer_guidance_stop: int = 0.2,
|
||||
skip_layer_guidance_start: int = 0.01,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
@@ -781,6 +870,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
|
||||
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
|
||||
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -938,7 +1032,22 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
# 6. Prepare image embeddings
|
||||
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
|
||||
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
|
||||
else:
|
||||
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
|
||||
@@ -289,6 +289,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
|
||||
@@ -291,14 +291,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
elif self.config.use_flow_sigmas:
|
||||
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_max = (
|
||||
|
||||
@@ -318,6 +318,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
|
||||
@@ -381,6 +381,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
|
||||
@@ -392,6 +392,21 @@ class FluxControlImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
|
||||
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
|
||||
|
||||
ignore_patterns = ["*.json", "*.md"]
|
||||
if not local_files_only:
|
||||
# `model_info` call must guarded with the above condition.
|
||||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
|
||||
for shard_file in original_shard_filenames:
|
||||
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
|
||||
if not shard_file_present:
|
||||
raise EnvironmentError(
|
||||
f"{shards_path} does not appear to have a file named {shard_file} which is "
|
||||
"required according to the checkpoint index."
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
except HTTPError as e:
|
||||
# `model_info` call must guarded with the above condition.
|
||||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
|
||||
for shard_file in original_shard_filenames:
|
||||
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
|
||||
if not shard_file_present:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
|
||||
" again after checking your internet connection."
|
||||
) from e
|
||||
f"{shards_path} does not appear to have a file named {shard_file} which is "
|
||||
"required according to the checkpoint index."
|
||||
)
|
||||
|
||||
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
|
||||
elif local_files_only:
|
||||
_check_if_shards_exist_locally(
|
||||
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
|
||||
try:
|
||||
# Load from URL
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cache_dir, subfolder)
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
except HTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
|
||||
" again after checking your internet connection."
|
||||
) from e
|
||||
|
||||
return cached_folder, sharded_metadata
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_peft_available,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
@@ -37,9 +36,6 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
pass
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
@@ -340,21 +340,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
|
||||
}
|
||||
# We should error out because lora input features is less than original. We only
|
||||
# support expanding the module, not shrinking it
|
||||
with self.assertRaises(NotImplementedError):
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -430,10 +415,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_expanding_shape_with_normal_lora_raises_error(self):
|
||||
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
|
||||
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
|
||||
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
|
||||
def test_lora_expanding_shape_with_normal_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
|
||||
# tested with it.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
@@ -478,27 +463,18 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
|
||||
# input features before expansion. This should raise an error about the weight shapes being incompatible.
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size mismatch for x_embedder.lora_A.adapter-2.weight",
|
||||
pipe.load_lora_weights,
|
||||
lora_state_dict,
|
||||
"adapter-2",
|
||||
)
|
||||
# We should have `adapter-1` as the only adapter.
|
||||
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
|
||||
# Check if the output is the same after lora loading error
|
||||
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
|
||||
|
||||
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
|
||||
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
|
||||
# original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
|
||||
# weight is compatible with the current model inadequate. This should be addressed when attempting support for
|
||||
# https://github.com/huggingface/diffusers/issues/10180 (TODO)
|
||||
# This should raise a runtime error on input shapes being incompatible.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
# Change the transformer config to mimic a real use case.
|
||||
num_channels_without_control = 4
|
||||
@@ -521,14 +497,11 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features)
|
||||
self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
|
||||
@@ -546,6 +519,98 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"adapter-2",
|
||||
)
|
||||
|
||||
def test_fuse_expanded_lora_with_regular_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
|
||||
# tested with it.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
num_channels_without_control = 4
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
components["transformer"] = transformer
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
|
||||
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
|
||||
}
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
|
||||
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
|
||||
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
|
||||
|
||||
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
|
||||
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_load_regular_lora(self):
|
||||
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
|
||||
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
|
||||
# transformers include Flux Fill, Flux Control, etc.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
|
||||
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
|
||||
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
@@ -760,3 +825,40 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
|
||||
def test_lora_with_turbo(self, lora_ckpt_id):
|
||||
self.pipeline.load_lora_weights(lora_ckpt_id)
|
||||
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
if "Canny" in lora_ckpt_id:
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
|
||||
)
|
||||
else:
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
|
||||
)
|
||||
|
||||
image = self.pipeline(
|
||||
prompt=self.prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
|
||||
out_slice = image[0, -3:, -3:, -1].flatten()
|
||||
if "Canny" in lora_ckpt_id:
|
||||
expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
|
||||
else:
|
||||
expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanVideoPipeline,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 10,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"num_refiner_layers": 1,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
}
|
||||
transformer_cls = HunyuanVideoTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"down_block_types": (
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
),
|
||||
"up_block_types": (
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": 1,
|
||||
"act_fn": "silu",
|
||||
"norm_num_groups": 4,
|
||||
"scaling_factor": 0.476986,
|
||||
"spatial_compression_ratio": 8,
|
||||
"temporal_compression_ratio": 4,
|
||||
"mid_block_add_attention": True,
|
||||
}
|
||||
vae_cls = AutoencoderKLHunyuanVideo
|
||||
has_two_text_encoders = True
|
||||
tokenizer_cls, tokenizer_id, tokenizer_subfolder = (
|
||||
LlamaTokenizerFast,
|
||||
"hf-internal-testing/tiny-random-hunyuanvideo",
|
||||
"tokenizer",
|
||||
)
|
||||
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = (
|
||||
CLIPTokenizer,
|
||||
"hf-internal-testing/tiny-random-hunyuanvideo",
|
||||
"tokenizer_2",
|
||||
)
|
||||
text_encoder_cls, text_encoder_id, text_encoder_subfolder = (
|
||||
LlamaModel,
|
||||
"hf-internal-testing/tiny-random-hunyuanvideo",
|
||||
"text_encoder",
|
||||
)
|
||||
text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = (
|
||||
CLIPTextModel,
|
||||
"hf-internal-testing/tiny-random-hunyuanvideo",
|
||||
"text_encoder_2",
|
||||
)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 32, 32, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 16
|
||||
num_channels = 4
|
||||
num_frames = 9
|
||||
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
|
||||
sizes = (4, 4)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "",
|
||||
"num_frames": num_frames,
|
||||
"num_inference_steps": 1,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": sequence_length,
|
||||
"prompt_template": {"template": "{}", "crop_start": 0},
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
prompt=inputs["prompt"],
|
||||
height=inputs["height"],
|
||||
width=inputs["width"],
|
||||
num_frames=inputs["num_frames"],
|
||||
num_inference_steps=inputs["num_inference_steps"],
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
output_type="np",
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
# TODO(aryan): Fix the following test
|
||||
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
@@ -23,7 +23,6 @@ from transformers import AutoTokenizer, T5EncoderModel
|
||||
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_peft_available,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
@@ -31,9 +30,6 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
pass
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
@@ -29,7 +29,6 @@ from diffusers import (
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
is_peft_available,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
@@ -37,9 +36,6 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
pass
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
+22
-12
@@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
has_two_text_encoders = False
|
||||
has_three_text_encoders = False
|
||||
text_encoder_cls, text_encoder_id = None, None
|
||||
text_encoder_2_cls, text_encoder_2_id = None, None
|
||||
text_encoder_3_cls, text_encoder_3_id = None, None
|
||||
tokenizer_cls, tokenizer_id = None, None
|
||||
tokenizer_2_cls, tokenizer_2_id = None, None
|
||||
tokenizer_3_cls, tokenizer_3_id = None, None
|
||||
text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
|
||||
text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
|
||||
text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
|
||||
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
|
||||
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
|
||||
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
|
||||
|
||||
unet_kwargs = None
|
||||
transformer_cls = None
|
||||
@@ -124,16 +124,26 @@ class PeftLoraLoaderMixinTests:
|
||||
torch.manual_seed(0)
|
||||
vae = self.vae_cls(**self.vae_kwargs)
|
||||
|
||||
text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
|
||||
tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
|
||||
text_encoder = self.text_encoder_cls.from_pretrained(
|
||||
self.text_encoder_id, subfolder=self.text_encoder_subfolder
|
||||
)
|
||||
tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
|
||||
|
||||
if self.text_encoder_2_cls is not None:
|
||||
text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
|
||||
tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)
|
||||
text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
|
||||
self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
|
||||
)
|
||||
tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
|
||||
self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
|
||||
)
|
||||
|
||||
if self.text_encoder_3_cls is not None:
|
||||
text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
|
||||
tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)
|
||||
text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
|
||||
self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
|
||||
)
|
||||
tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
|
||||
self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
|
||||
)
|
||||
|
||||
text_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
|
||||
@@ -43,10 +43,14 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
|
||||
"down_block_types": (
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
),
|
||||
"up_block_types": (
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": 1,
|
||||
@@ -154,6 +158,27 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# We need to overwrite this test because the base test does not account length of down_block_types
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 16, 16, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@@ -146,7 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder"}
|
||||
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
|
||||
@@ -65,7 +65,7 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Encoder", "TemporalDecoder"}
|
||||
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Test unsupported.")
|
||||
|
||||
@@ -803,7 +803,7 @@ class ModelTesterMixin:
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
@require_torch_accelerator_with_training
|
||||
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
|
||||
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
|
||||
@@ -850,6 +850,8 @@ class ModelTesterMixin:
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
|
||||
|
||||
@@ -18,6 +18,8 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
@@ -26,6 +28,56 @@ from ..test_modeling_common import ModelTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def create_flux_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
continue
|
||||
|
||||
joint_attention_dim = model.config["joint_attention_dim"]
|
||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||
sd = FluxIPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=model.config["pooled_projection_dim"],
|
||||
num_image_text_embeds=4,
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = {}
|
||||
sd = image_projection.state_dict()
|
||||
ip_image_projection_state_dict.update(
|
||||
{
|
||||
"proj.weight": sd["image_embeds.weight"],
|
||||
"proj.bias": sd["image_embeds.bias"],
|
||||
"norm.weight": sd["norm.weight"],
|
||||
"norm.bias": sd["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
del sd
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
|
||||
@@ -105,6 +105,23 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"AttnUpBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"UNetMidBlock2D",
|
||||
"UpBlock2D",
|
||||
"DownBlock2D",
|
||||
}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
attention_head_dim = 8
|
||||
block_out_channels = (16, 32)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
|
||||
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
@@ -220,6 +237,17 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
attention_head_dim = 32
|
||||
block_out_channels = (32, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
@@ -329,3 +357,17 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def test_forward_with_norm_groups(self):
|
||||
# not required for this model
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
|
||||
block_out_channels = (32, 64, 64, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
|
||||
|
||||
@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
|
||||
|
||||
# Load the EMA model from the saved directory
|
||||
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
|
||||
loaded_ema_unet.to(torch_device)
|
||||
|
||||
# Check that the shadow parameters of the loaded model match the original EMA model
|
||||
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
|
||||
@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
|
||||
|
||||
# Load the EMA model from the saved directory
|
||||
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
|
||||
loaded_ema_unet.to(torch_device)
|
||||
|
||||
# Check that the shadow parameters of the loaded model match the original EMA model
|
||||
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
|
||||
|
||||
@@ -16,13 +16,14 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
|
||||
pipeline_class = FluxPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -91,6 +92,8 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"image_encoder": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
@@ -296,3 +299,112 @@ class FluxPipelineSlowTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxPipeline
|
||||
repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
|
||||
weight_name = "ip_adapter.safetensors"
|
||||
ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
prompt_embeds = torch.load(
|
||||
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
|
||||
)
|
||||
pooled_prompt_embeds = torch.load(
|
||||
hf_hub_download(
|
||||
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
|
||||
)
|
||||
)
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||
ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
return {
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"pooled_prompt_embeds": pooled_prompt_embeds,
|
||||
"negative_prompt_embeds": negative_prompt_embeds,
|
||||
"negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
|
||||
"ip_adapter_image": ip_adapter_image,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.5,
|
||||
"true_cfg_scale": 4.0,
|
||||
"max_sequence_length": 256,
|
||||
"output_type": "np",
|
||||
"generator": generator,
|
||||
}
|
||||
|
||||
def test_flux_ip_adapter_inference(self):
|
||||
pipe = self.pipeline_class.from_pretrained(
|
||||
self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
|
||||
)
|
||||
pipe.load_ip_adapter(
|
||||
self.ip_adapter_repo_id,
|
||||
weight_name=self.weight_name,
|
||||
image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path,
|
||||
)
|
||||
pipe.set_ip_adapter_scale(1.0)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.1855,
|
||||
0.1680,
|
||||
0.1406,
|
||||
0.1953,
|
||||
0.1699,
|
||||
0.1465,
|
||||
0.2012,
|
||||
0.1738,
|
||||
0.1484,
|
||||
0.2051,
|
||||
0.1797,
|
||||
0.1523,
|
||||
0.2012,
|
||||
0.1719,
|
||||
0.1445,
|
||||
0.2070,
|
||||
0.1777,
|
||||
0.1465,
|
||||
0.2090,
|
||||
0.1836,
|
||||
0.1484,
|
||||
0.2129,
|
||||
0.1875,
|
||||
0.1523,
|
||||
0.2090,
|
||||
0.1816,
|
||||
0.1484,
|
||||
0.2110,
|
||||
0.1836,
|
||||
0.1543,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
|
||||
|
||||
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = FluxControlInpaintPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
# there is no xformers processor for Flux
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = FluxTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=8,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
num_single_layers=1,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=32,
|
||||
pooled_projection_dim=32,
|
||||
axes_dims_rope=[4, 4, 8],
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
image = Image.new("RGB", (8, 8), 0)
|
||||
control_image = Image.new("RGB", (8, 8), 0)
|
||||
mask_image = Image.new("RGB", (8, 8), 255)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"control_image": control_image,
|
||||
"generator": generator,
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"strength": 0.8,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 30.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
# def test_flux_different_prompts(self):
|
||||
# pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
# inputs = self.get_dummy_inputs(torch_device)
|
||||
# output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
# inputs = self.get_dummy_inputs(torch_device)
|
||||
# inputs["prompt_2"] = "a different prompt"
|
||||
# output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
# max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# # Outputs should be different here
|
||||
# # For some reasons, they don't show large differences
|
||||
# assert max_diff > 1e-6
|
||||
|
||||
def test_flux_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
device=torch_device,
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
assert (output_height, output_width) == (expected_height, expected_width)
|
||||
@@ -275,7 +275,7 @@ class MochiPipelineIntegrationTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_cogvideox(self):
|
||||
def test_mochi(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
|
||||
|
||||
@@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests(
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
def test_pipeline_accept_tuple_type_unet_sample_size(self):
|
||||
# the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size
|
||||
sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
sample_size = [60, 80]
|
||||
customised_unet = UNet2DConditionModel(sample_size=sample_size)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)
|
||||
assert pipe.unet.config.sample_size == sample_size
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -103,6 +103,8 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
"tokenizer_3": tokenizer_3,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"image_encoder": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
|
||||
@@ -1802,6 +1802,16 @@ class PipelineFastTests(unittest.TestCase):
|
||||
sd.maybe_free_model_hooks()
|
||||
assert sd._offload_gpu_id == 5
|
||||
|
||||
def test_wrong_model(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
with self.assertRaises(ValueError) as error_context:
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
|
||||
)
|
||||
|
||||
assert "is of type" in str(error_context.exception)
|
||||
assert "but should be" in str(error_context.exception)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
|
||||
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
|
||||
@@ -54,6 +54,7 @@ from ..models.autoencoders.vae import (
|
||||
get_autoencoder_tiny_config,
|
||||
get_consistency_vae_config,
|
||||
)
|
||||
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
|
||||
from ..models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_faceid_state_dict,
|
||||
create_ip_adapter_state_dict,
|
||||
@@ -483,6 +484,94 @@ class IPAdapterTesterMixin:
|
||||
)
|
||||
|
||||
|
||||
class FluxIPAdapterTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
|
||||
It provides a set of common tests for pipelines that support IP Adapters.
|
||||
"""
|
||||
|
||||
def test_pipeline_signature(self):
|
||||
parameters = inspect.signature(self.pipeline_class.__call__).parameters
|
||||
|
||||
assert issubclass(self.pipeline_class, FluxIPAdapterMixin)
|
||||
self.assertIn(
|
||||
"ip_adapter_image",
|
||||
parameters,
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method",
|
||||
)
|
||||
self.assertIn(
|
||||
"ip_adapter_image_embeds",
|
||||
parameters,
|
||||
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
|
||||
return torch.randn((1, 1, image_embed_dim), device=torch_device)
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
inputs["negative_prompt"] = ""
|
||||
inputs["true_cfg_scale"] = 4.0
|
||||
inputs["output_type"] = "np"
|
||||
inputs["return_dict"] = False
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for IP-Adapter.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
image_embed_dim = pipe.transformer.config.pooled_projection_dim
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs)[0]
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
|
||||
pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
|
||||
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs)[0]
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
|
||||
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs)[0]
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
self.assertLess(
|
||||
max_diff_without_adapter_scale,
|
||||
expected_max_diff,
|
||||
"Output without ip-adapter must be same as normal inference",
|
||||
)
|
||||
self.assertGreater(
|
||||
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class PipelineLatentTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
|
||||
|
||||
@@ -278,13 +278,14 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertEqual(weight.quant_max, 15)
|
||||
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||
|
||||
def test_offload(self):
|
||||
def test_device_map(self):
|
||||
"""
|
||||
Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
|
||||
that the device map is correctly set (in the `hf_device_map` attribute of the model).
|
||||
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
|
||||
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
|
||||
correctly set (in the `hf_device_map` attribute of the model).
|
||||
"""
|
||||
|
||||
device_map_offload = {
|
||||
custom_device_map_dict = {
|
||||
"time_text_embed": torch_device,
|
||||
"context_embedder": torch_device,
|
||||
"x_embedder": torch_device,
|
||||
@@ -293,27 +294,50 @@ class TorchAoTest(unittest.TestCase):
|
||||
"norm_out": torch_device,
|
||||
"proj_out": "cpu",
|
||||
}
|
||||
device_maps = ["auto", custom_device_map_dict]
|
||||
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
|
||||
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map_offload,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
for device_map in device_maps:
|
||||
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
|
||||
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_offload)
|
||||
# Test non-sharded model
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
|
||||
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# Test sharded model
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
|
||||
|
||||
Reference in New Issue
Block a user