Compare commits

...

14 Commits

Author SHA1 Message Date
yiyixuxu c73c00610e add:
q
2024-12-22 10:21:00 +01:00
Mehmet Yiğit Özgenç 233dffdc3f flux controlnet inpaint config bug (#10291)
* flux controlnet inpaint config bug

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

---------

Co-authored-by: yigitozgenc <yigit@quantuslabs.ai>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-21 18:44:43 +00:00
hlky be2070991f Support Flux IP Adapter (#10261)
* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-21 17:49:58 +00:00
hlky bf9a641f1a Fix EMAModel test_from_pretrained (#10325) 2024-12-21 14:10:44 +00:00
hlky a756694bf0 Fix push_tests_mps.yml (#10326) 2024-12-21 14:10:32 +00:00
Sayak Paul d41388145e [Docs] Update gguf.md to remove generator from the pipeline from_pretrained (#10299)
Update gguf.md to remove generator from the pipeline from_pretrained
2024-12-21 07:15:03 +05:30
Junsong Chen a6288a5571 [Sana]add 2K related model for Sana (#10322)
add 2K related model for Sana
2024-12-20 07:21:34 -10:00
Steven Liu 7d4db57037 [docs] Fix quantization links (#10323)
Update overview.md
2024-12-20 08:30:21 -08:00
Aditya Raj 902008608a [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)
[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float

Co-authored-by: hlky <hlky@hlky.ac>
2024-12-20 15:29:58 +00:00
Leojc c8ee4af228 docs: fix a mistake in docstring (#10319)
Update pipeline_hunyuan_video.py

docs: fix a mistake
2024-12-20 15:22:32 +00:00
Sayak Paul b64ca6c11c [Docs] Update ltx_video.md to remove generator from from_pretrained() (#10316)
Update ltx_video.md to remove generator from `from_pretrained()`
2024-12-20 18:32:22 +05:30
Dhruv Nair e12d610faa Mochi docs (#9934)
* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-20 16:27:38 +05:30
Sayak Paul bf6eaa8aec [Tests] add integration tests for lora expansion stuff in Flux. (#10318)
add integration tests for lora expansion stuff in Flux.
2024-12-20 16:14:58 +05:30
Sayak Paul 17128c42a4 [LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259)
* lora expansion with dummy zeros.

* updates

* fix working 🥳

* working.

* use torch.device meta for state dict expansion.

* tests

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>

* fixes

* fixes

* switch to debug

* fix

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* fix stuff

* docs

---------

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-20 14:30:32 +05:30
26 changed files with 1676 additions and 123 deletions
+1 -1
View File
@@ -46,7 +46,7 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
${CONDA_RUN} python -m uv pip install -e [quality,test]
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade
+37
View File
@@ -268,6 +268,43 @@ images = pipe(
images[0].save("flux-redux.png")
```
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
```py
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = control_pipe(
prompt=prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
@@ -79,7 +79,6 @@ transformer = LTXVideoTransformer3DModel.from_single_file(
pipe = LTXPipeline.from_pretrained(
"Lightricks/LTX-Video",
transformer=transformer,
generator=torch.manual_seed(0),
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
+196 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
-->
# Mochi
# Mochi 1 Preview
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
## Generating videos with Mochi-1 Preview
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
# Enable memory savings
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
frames = pipe(prompt, num_frames=85).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)
```
## Using a lower precision variant to save memory
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
# Enable memory savings
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
frames = pipe(prompt, num_frames=85).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)
```
## Reproducing the results from the Genmo Mochi repo
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
<Tip>
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
</Tip>
<Tip>
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
</Tip>
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
with torch.no_grad():
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
pipe.encode_prompt(prompt=prompt)
)
with torch.autocast("cuda", torch.bfloat16):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
frames = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
guidance_scale=4.5,
num_inference_steps=64,
height=480,
width=848,
num_frames=163,
generator=torch.Generator("cuda").manual_seed(0),
output_type="latent",
return_dict=False,
)[0]
video_processor = VideoProcessor(vae_scale_factor=8)
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
latents_std = (
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
else:
frames = frames / pipe.vae.config.scaling_factor
with torch.no_grad():
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
video = video_processor.postprocess_video(video)[0]
export_to_video(video, "mochi.mp4", fps=30)
```
## Running inference with multiple GPUs
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
model_id = "genmo/mochi-1-preview"
transformer = MochiTransformer3DModel.from_pretrained(
model_id,
subfolder="transformer",
device_map="auto",
max_memory={0: "24GB", 1: "24GB"}
)
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
frames = pipe(
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
negative_prompt="",
height=480,
width=848,
num_frames=85,
num_inference_steps=50,
guidance_scale=4.5,
num_videos_per_prompt=1,
generator=torch.Generator(device="cuda").manual_seed(0),
max_sequence_length=256,
output_type="pil",
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
## Using single file loading with the Mochi Transformer
You can use `from_single_file` to load the Mochi transformer in its original format.
<Tip>
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
</Tip>
```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
model_id = "genmo/mochi-1-preview"
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
frames = pipe(
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
negative_prompt="",
height=480,
width=848,
num_frames=85,
num_inference_steps=50,
guidance_scale=4.5,
num_videos_per_prompt=1,
generator=torch.Generator(device="cuda").manual_seed(0),
max_sequence_length=256,
output_type="pil",
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
## MochiPipeline
[[autodoc]] MochiPipeline
+1 -2
View File
@@ -45,12 +45,11 @@ transformer = FluxTransformer2DModel.from_single_file(
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
generator=torch.manual_seed(0),
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt).images[0]
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
+3 -3
View File
@@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes.md)
- [TorchAO](./torchao.md)
- [GGUF](./gguf.md)
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext
import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
if is_transformers_available():
from transformers import CLIPVisionModelWithProjection
vision = True
else:
vision = False
"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors"
--output_path "flux-ip-adapter-hf/"
"""
CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
args = parser.parse_args()
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
converted_state_dict = {}
# image_proj
## norm
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
## proj
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"ip_adapter.{i}."
# to_k_ip
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
)
# to_v_ip
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
)
return converted_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
num_layers = 19
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
print("Saving Flux IP-Adapter in Diffusers format.")
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
if vision:
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
model.save_pretrained(f"{args.output_path}/image_encoder")
if __name__ == "__main__":
main(args)
+3 -2
View File
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
@@ -265,9 +266,9 @@ if __name__ == "__main__":
"--image_size",
default=1024,
type=int,
choices=[512, 1024],
choices=[512, 1024, 2048],
required=False,
help="Image size of pretrained model, 512 or 1024.",
help="Image size of pretrained model, 512, 1024 or 2048.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
+4 -1
View File
@@ -55,7 +55,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
@@ -77,6 +77,7 @@ if is_torch_available():
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]
@@ -86,12 +87,14 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
+286
View File
@@ -38,6 +38,8 @@ if is_transformers_available():
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -353,6 +355,290 @@ class IPAdapterMixin:
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
Can be either:
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
for key in f.keys():
if any(key.startswith(prefix) for prefix in image_proj_keys):
diffusers_name = ".".join(key.split(".")[1:])
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
diffusers_name = (
".".join(key.split(".")[1:])
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
.replace("processor.", "")
)
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_pretrained_model_name_or_path is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
image_encoder = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_pretrained_model_name_or_path,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.to(self.device, dtype=image_encoder_dtype)
.eval()
)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
number of IP adapters and each must match the number of blocks.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
def LinearStrengthModel(start, finish, size):
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
transformer = self.transformer
if not isinstance(scale, list):
scale = [[scale] * transformer.config.num_layers]
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
if len(scale) != transformer.config.num_layers:
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
scale = [scale]
scale_configs = scale
key_id = 0
for attn_name, attn_processor in transformer.attn_processors.items():
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
attn_processor.scale[i] = scale_config[key_id]
key_id += 1
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove hidden encoder
self.transformer.encoder_hid_proj = None
self.transformer.config.encoder_hid_dim_type = None
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
class SD3IPAdapterMixin:
"""Mixin for handling StableDiffusion 3 IP Adapters."""
+94 -41
View File
@@ -1863,6 +1863,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging."
)
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)
if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer(
@@ -2309,16 +2312,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None
lora_A_weight_name = f"{name}.lora_A.weight"
lora_B_weight_name = f"{name}.lora_B.weight"
if lora_A_weight_name not in state_dict.keys():
lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
if lora_A_weight_name not in state_dict:
continue
in_features = state_dict[lora_A_weight_name].shape[1]
@@ -2329,57 +2333,106 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
continue
module_out_features, module_in_features = module_weight.shape
if out_features < module_out_features or in_features < module_in_features:
raise NotImplementedError(
f"Only LoRAs with input/output features higher than the current module's input/output features "
f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
f"this please open an issue at https://github.com/huggingface/diffusers/issues."
debug_message = ""
if in_features > module_in_features:
debug_message += (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)
debug_message = (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)
if module_out_features != out_features:
if out_features > module_out_features:
debug_message += (
", and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
)
else:
debug_message += "."
logger.debug(debug_message)
if debug_message:
logger.debug(debug_message)
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if out_features > module_out_features or in_features > module_in_features:
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
# TODO: consider initializing this under meta device for optims.
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
expanded_module.weight.data.copy_(new_weight)
if module_bias is not None:
expanded_module.bias.data.copy_(module_bias)
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not. This is because only the input dimensions
# are changed while the output dimensions remain the same. The shape of the weight tensor
# is (out_features, in_features), while the shape of bias tensor is (out_features,), which
# explains the reason why only weights are expanded.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
tmp_state_dict["bias"] = module_bias
expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
setattr(parent_module, current_module_name, expanded_module)
setattr(parent_module, current_module_name, expanded_module)
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
del tmp_state_dict
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
return has_param_with_shape_update
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
expanded_module_names = set()
transformer_state_dict = transformer.state_dict()
prefix = f"{cls.transformer_name}."
lora_module_names = [
key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
]
lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
lora_module_names = sorted(set(lora_module_names))
transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
unexpected_modules = set(lora_module_names) - set(transformer_module_names)
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for k in lora_module_names:
if k in unexpected_modules:
continue
base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
if base_weight_param.shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
if expanded_module_names:
logger.info(
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
)
return lora_state_dict
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+179
View File
@@ -0,0 +1,179 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
if is_accelerate_available():
pass
logger = logging.get_logger(__name__)
class FluxTransformer2DLoadersMixin:
"""
Load layers into a [`FluxTransformer2DModel`].
"""
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 0
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
num_image_text_embed = 4
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
num_image_text_embed = 16
# IP-Adapter
num_image_text_embeds += [num_image_text_embed]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
dtype=self.dtype,
device=self.device,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device = self.device
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
key_id += 1
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
+145 -1
View File
@@ -575,7 +575,7 @@ class Attention(nn.Module):
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
@@ -2653,6 +2653,149 @@ class FusedFluxAttnProcessor2_0_NPU:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_attn_output = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_attn_output = scale * ip_attn_output
ip_attn_output = ip_attn_output.to(ip_query.dtype)
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -5896,6 +6039,7 @@ CROSS_ATTENTION_PROCESSORS = (
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
)
AttentionProcessor = Union[
+1 -1
View File
@@ -1535,7 +1535,7 @@ class ImageProjection(nn.Module):
batch_size = image_embeds.shape[0]
# image
image_embeds = self.image_embeds(image_embeds)
image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
+31 -13
View File
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try:
return next(parameter.parameters()).dtype
except StopIteration:
try:
return next(parameter.buffers()).dtype
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
last_dtype = None
for param in parameter.parameters():
last_dtype = param.dtype
if param.is_floating_point():
return param.dtype
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
for buffer in parameter.buffers():
last_dtype = buffer.dtype
if buffer.is_floating_point():
return buffer.dtype
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
class ModelMixin(torch.nn.Module, PushToHubMixin):
@@ -21,7 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
@@ -177,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attn_output, context_attn_output = self.attn(
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
@@ -195,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
@@ -212,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
):
"""
The Transformer model introduced in Flux.
@@ -482,6 +491,11 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
+173 -5
View File
@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,6 +149,7 @@ class FluxPipeline(
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
FluxIPAdapterMixin,
):
r"""
The Flux pipeline for text-to-image generation.
@@ -169,8 +177,8 @@ class FluxPipeline(
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -182,6 +190,8 @@ class FluxPipeline(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -193,6 +203,8 @@ class FluxPipeline(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -377,14 +389,60 @@ class FluxPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
image_embeds.append(single_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -419,10 +477,33 @@ class FluxPipeline(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -551,6 +632,9 @@ class FluxPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
@@ -561,6 +645,12 @@ class FluxPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -610,6 +700,17 @@ class FluxPipeline(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image:
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -647,8 +748,12 @@ class FluxPipeline(
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -670,6 +775,7 @@ class FluxPipeline(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -684,6 +790,21 @@ class FluxPipeline(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -725,12 +846,43 @@ class FluxPipeline(
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -746,6 +898,22 @@ class FluxPipeline(
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -403,7 +403,6 @@ class FluxControlPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -1095,7 +1095,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# predict the noise residual
if self.controlnet.config.guidance_embeds:
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds
if use_guidance:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
@@ -143,7 +143,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Args:
text_encoder ([`LlamaModel`]):
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
tokenizer_2 (`LlamaTokenizer`):
tokenizer (`LlamaTokenizer`):
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
transformer ([`HunyuanVideoTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
@@ -446,7 +446,7 @@ class StableAudioPipeline(DiffusionPipeline):
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
# check num_channels
+144 -42
View File
@@ -340,21 +340,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because lora input features is less than original. We only
# support expanding the module, not shrinking it
with self.assertRaises(NotImplementedError):
pipe.load_lora_weights(lora_state_dict, "adapter-1")
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -430,10 +415,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_lora_expanding_shape_with_normal_lora_raises_error(self):
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
def test_lora_expanding_shape_with_normal_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
@@ -478,27 +463,18 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
# input features before expansion. This should raise an error about the weight shapes being incompatible.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
# We should have `adapter-1` as the only adapter.
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-2")
# Check if the output is the same after lora loading error
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
# original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
# weight is compatible with the current model inadequate. This should be addressed when attempting support for
# https://github.com/huggingface/diffusers/issues/10180 (TODO)
# This should raise a runtime error on input shapes being incompatible.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
@@ -521,14 +497,11 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
@@ -546,6 +519,98 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"adapter-2",
)
def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
components["transformer"] = transformer
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
def test_load_regular_lora(self):
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
# transformers include Flux Fill, Flux Control, etc.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -760,3 +825,40 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora_with_turbo(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
)
else:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)
image = self.pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=self.num_inference_steps,
guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = image[0, -3:, -3:, -1].flatten()
if "Canny" in lora_ckpt_id:
expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
else:
expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
@@ -18,6 +18,8 @@ import unittest
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -26,6 +28,56 @@ from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
+2
View File
@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+113 -1
View File
@@ -16,13 +16,14 @@ from diffusers.utils.testing_utils import (
)
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -91,6 +92,8 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -296,3 +299,112 @@ class FluxPipelineSlowTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
weight_name = "ip_adapter.safetensors"
ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
prompt_embeds = torch.load(
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
)
pooled_prompt_embeds = torch.load(
hf_hub_download(
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
)
)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
return {
"prompt_embeds": prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
"ip_adapter_image": ip_adapter_image,
"num_inference_steps": 2,
"guidance_scale": 3.5,
"true_cfg_scale": 4.0,
"max_sequence_length": 256,
"output_type": "np",
"generator": generator,
}
def test_flux_ip_adapter_inference(self):
pipe = self.pipeline_class.from_pretrained(
self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
)
pipe.load_ip_adapter(
self.ip_adapter_repo_id,
weight_name=self.weight_name,
image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path,
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.1855,
0.1680,
0.1406,
0.1953,
0.1699,
0.1465,
0.2012,
0.1738,
0.1484,
0.2051,
0.1797,
0.1523,
0.2012,
0.1719,
0.1445,
0.2070,
0.1777,
0.1465,
0.2090,
0.1836,
0.1484,
0.2129,
0.1875,
0.1523,
0.2090,
0.1816,
0.1484,
0.2110,
0.1836,
0.1543,
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
+90 -1
View File
@@ -29,7 +29,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -54,6 +54,7 @@ from ..models.autoencoders.vae import (
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
@@ -483,6 +484,94 @@ class IPAdapterTesterMixin:
)
class FluxIPAdapterTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
It provides a set of common tests for pipelines that support IP Adapters.
"""
def test_pipeline_signature(self):
parameters = inspect.signature(self.pipeline_class.__call__).parameters
assert issubclass(self.pipeline_class, FluxIPAdapterMixin)
self.assertIn(
"ip_adapter_image",
parameters,
"`ip_adapter_image` argument must be supported by the `__call__` method",
)
self.assertIn(
"ip_adapter_image_embeds",
parameters,
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
)
def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
return torch.randn((1, 1, image_embed_dim), device=torch_device)
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
inputs["negative_prompt"] = ""
inputs["true_cfg_scale"] = 4.0
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_embed_dim = pipe.transformer.config.pooled_projection_dim
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs)[0]
else:
output_without_adapter = expected_pipe_slice
adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
)
class PipelineLatentTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.