Compare commits

..

41 Commits

Author SHA1 Message Date
yiyixuxu c73c00610e add:
q
2024-12-22 10:21:00 +01:00
Mehmet Yiğit Özgenç 233dffdc3f flux controlnet inpaint config bug (#10291)
* flux controlnet inpaint config bug

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

---------

Co-authored-by: yigitozgenc <yigit@quantuslabs.ai>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-21 18:44:43 +00:00
hlky be2070991f Support Flux IP Adapter (#10261)
* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-21 17:49:58 +00:00
hlky bf9a641f1a Fix EMAModel test_from_pretrained (#10325) 2024-12-21 14:10:44 +00:00
hlky a756694bf0 Fix push_tests_mps.yml (#10326) 2024-12-21 14:10:32 +00:00
Sayak Paul d41388145e [Docs] Update gguf.md to remove generator from the pipeline from_pretrained (#10299)
Update gguf.md to remove generator from the pipeline from_pretrained
2024-12-21 07:15:03 +05:30
Junsong Chen a6288a5571 [Sana]add 2K related model for Sana (#10322)
add 2K related model for Sana
2024-12-20 07:21:34 -10:00
Steven Liu 7d4db57037 [docs] Fix quantization links (#10323)
Update overview.md
2024-12-20 08:30:21 -08:00
Aditya Raj 902008608a [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)
[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float

Co-authored-by: hlky <hlky@hlky.ac>
2024-12-20 15:29:58 +00:00
Leojc c8ee4af228 docs: fix a mistake in docstring (#10319)
Update pipeline_hunyuan_video.py

docs: fix a mistake
2024-12-20 15:22:32 +00:00
Sayak Paul b64ca6c11c [Docs] Update ltx_video.md to remove generator from from_pretrained() (#10316)
Update ltx_video.md to remove generator from `from_pretrained()`
2024-12-20 18:32:22 +05:30
Dhruv Nair e12d610faa Mochi docs (#9934)
* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-20 16:27:38 +05:30
Sayak Paul bf6eaa8aec [Tests] add integration tests for lora expansion stuff in Flux. (#10318)
add integration tests for lora expansion stuff in Flux.
2024-12-20 16:14:58 +05:30
Sayak Paul 17128c42a4 [LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259)
* lora expansion with dummy zeros.

* updates

* fix working 🥳

* working.

* use torch.device meta for state dict expansion.

* tests

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>

* fixes

* fixes

* switch to debug

* fix

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* fix stuff

* docs

---------

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-20 14:30:32 +05:30
Dhruv Nair dbc1d505f0 [Single File] Add GGUF support for LTX (#10298)
* update

* add docs.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-20 11:52:29 +05:30
Aryan 151b74cd77 Make tensors in ResNet contiguous for Hunyuan VAE (#10309)
contiguous tensors in resnet

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-20 11:45:37 +05:30
Aryan 41ba8c0bf6 Add support for sharded models when TorchAO quantization is enabled (#10256)
* add sharded + device_map check
2024-12-19 15:42:20 -10:00
Daniel Regado 3191248472 [WIP] SD3.5 IP-Adapter Pipeline Integration (#9987)
* Added support for single IPAdapter on SD3.5 pipeline



---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-19 14:48:18 -10:00
dg845 648d968cfc Enable Gradient Checkpointing for UNet2DModel (New) (#7201)
* Port UNet2DModel gradient checkpointing code from #6718.


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-19 14:45:45 -10:00
djm b756ec6e80 unet's sample_size attribute is to accept tuple(h, w) in StableDiffusionPipeline (#10181) 2024-12-19 22:24:18 +00:00
Aryan d8825e7697 Fix failing lora tests after HunyuanVideo lora (#10307)
fix
2024-12-20 02:35:41 +05:30
hlky 074798b299 Fix local_files_only for checkpoints with shards (#10294) 2024-12-19 07:04:57 -10:00
Dhruv Nair 3ee966950b Allow Mochi Transformer to be split across multiple GPUs (#10300)
update
2024-12-19 22:34:44 +05:30
Dhruv Nair 9764f229d4 [Single File] Add single file support for Mochi Transformer (#10268)
update
2024-12-19 22:20:40 +05:30
Shenghai Yuan 1826a1e7d3 [LoRA] Support HunyuanVideo (#10254)
* 1217

* 1217

* 1217

* update

* reverse

* add test

* update test

* make style

* update

* make style

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-19 16:22:20 +05:30
hlky 0ed09a17bb Check correct model type is passed to from_pretrained (#10189)
* Check correct model type is passed to `from_pretrained`

* Flax, skip scheduler

* test_wrong_model

* Fix for scheduler

* Update tests/pipelines/test_pipelines.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* EnumMeta

* Flax

* scheduler in expected types

* make

* type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name'

* support union

* fix typing in kandinsky

* make

* add LCMScheduler

* 'LCMScheduler' object has no attribute 'sigmas'

* tests for wrong scheduler

* make

* update

* warning

* tests

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* import FlaxSchedulerMixin

* skip scheduler

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-12-19 09:24:52 +00:00
赵三石 2f7a417d1f Update lora_conversion_utils.py (#9980)
x-flux single-blocks lora load

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-18 23:07:50 -10:00
hlky 4450d26b63 Add Flux Control to AutoPipeline (#10292) 2024-12-18 22:28:56 -10:00
Aryan f781b8c30c Hunyuan VAE tiling fixes and transformer docs (#10295)
* update

* udpate

* fix test
2024-12-19 10:28:10 +05:30
Sayak Paul 9c0e20de61 [chore] Update README_sana.md to update the default model (#10285)
Update README_sana.md to update the default model
2024-12-19 10:24:57 +05:30
Aryan f35a38725b [tests] remove nullop import checks from lora tests (#10273)
remove nullop imports
2024-12-19 01:19:08 +05:30
Aryan f66bd3261c Rename Mochi integration test correctly (#10220)
rename integration test
2024-12-18 22:41:23 +05:30
Aryan c4c99c3907 [tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (#10270)
fix joint pos embedding device
2024-12-18 22:36:08 +05:30
Dhruv Nair 862a7d5038 [Single File] Add single file support for Flux Canny, Depth and Fill (#10288)
update
2024-12-18 19:19:47 +05:30
Dhruv Nair 8304adce2a Make zeroing prompt embeds for Mochi Pipeline configurable (#10284)
update
2024-12-18 18:32:53 +05:30
Dhruv Nair b389f339ec Fix Doc links in GGUF and Quantization overview docs (#10279)
* update

* Update docs/source/en/quantization/gguf.md

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-18 18:32:36 +05:30
hlky e222246b4e Fix sigma_last with use_flow_sigmas (#10267) 2024-12-18 12:22:10 +00:00
Andrés Romero 83709d5a06 Flux Control(Depth/Canny) + Inpaint (#10192)
* flux_control_inpaint - failing test_flux_different_prompts

* removing test_flux_different_prompts?

* fix style

* fix from PR comments

* fix style

* reducing guidance_scale in demo

* Update src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Co-authored-by: hlky <hlky@hlky.ac>

* make

* prepare_latents is not copied from

* update docs

* typos

---------

Co-authored-by: affromero <ubuntu@ip-172-31-17-146.ec2.internal>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-18 09:14:16 +00:00
Qin Zhou 8eb73c872a Support pass kwargs to sd3 custom attention processor (#9818)
* Support pass kwargs to sd3 custom attention processor


---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-17 21:58:33 -10:00
Xinyuan Zhao 88b015dc9f Make time_embed_dim of UNet2DModel changeable (#10262) 2024-12-17 21:55:18 -10:00
Sayak Paul 63cdf9c0ba [chore] fix: reamde -> readme (#10276)
fix: reamde -> readme
2024-12-18 10:56:08 +05:30
78 changed files with 5286 additions and 281 deletions
+1 -1
View File
@@ -46,7 +46,7 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
${CONDA_RUN} python -m uv pip install -e [quality,test]
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade
+4
View File
@@ -238,6 +238,8 @@
title: Textual Inversion
- local: api/loaders/unet
title: UNet
- local: api/loaders/transformer_sd3
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
title: Loaders
@@ -400,6 +402,8 @@
title: DiT
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
+2
View File
@@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
## JointAttnProcessor2_0
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
+6
View File
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
## SD3IPAdapterMixin
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
- all
- is_ip_adapter_active
## IPAdapterMaskProcessor
[[autodoc]] image_processor.IPAdapterMaskProcessor
@@ -0,0 +1,29 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SD3Transformer2D
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
<Tip>
To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
</Tip>
## SD3Transformer2DLoadersMixin
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
- all
- _load_ip_adapter_weights
@@ -0,0 +1,89 @@
<!--Copyright 2024 The HuggingFace Team, The Black Forest Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FluxControlInpaint
FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
| Control type | Developer | Link |
| -------- | ---------- | ---- |
| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
<Tip>
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
</Tip>
```python
import torch
from diffusers import FluxControlInpaintPipeline
from diffusers.models.transformers import FluxTransformer2DModel
from transformers import T5EncoderModel
from diffusers.utils import load_image, make_image_grid
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
from PIL import Image
import numpy as np
pipe = FluxControlInpaintPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Depth-dev",
torch_dtype=torch.bfloat16,
)
# use following lines if you have GPU constraints
# ---------------------------------------------------------------
transformer = FluxTransformer2DModel.from_pretrained(
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder_2 = T5EncoderModel.from_pretrained(
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()
# ---------------------------------------------------------------
pipe.to("cuda")
prompt = "a blue robot singing opera with human-like expressions"
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
head_mask = np.zeros_like(image)
head_mask[65:580,300:642] = 255
mask_image = Image.fromarray(head_mask)
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(image)[0].convert("RGB")
output = pipe(
prompt=prompt,
image=image,
control_image=control_image,
mask_image=mask_image,
num_inference_steps=30,
strength=0.9,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
```
## FluxControlInpaintPipeline
[[autodoc]] FluxControlInpaintPipeline
- all
- __call__
## FluxPipelineOutput
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
+37
View File
@@ -268,6 +268,43 @@ images = pipe(
images[0].save("flux-redux.png")
```
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
```py
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = control_pipe(
prompt=prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
+38
View File
@@ -61,6 +61,44 @@ pipe = LTXImageToVideoPipeline.from_single_file(
)
```
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
```py
import torch
from diffusers.utils import export_to_video
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
)
transformer = LTXVideoTransformer3DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = LTXPipeline.from_pretrained(
"Lightricks/LTX-Video",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=704,
height=480,
num_frames=161,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
```
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
## LTXPipeline
+196 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
-->
# Mochi
# Mochi 1 Preview
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
## Generating videos with Mochi-1 Preview
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
# Enable memory savings
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
frames = pipe(prompt, num_frames=85).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)
```
## Using a lower precision variant to save memory
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
# Enable memory savings
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
frames = pipe(prompt, num_frames=85).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)
```
## Reproducing the results from the Genmo Mochi repo
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
<Tip>
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
</Tip>
<Tip>
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
</Tip>
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
with torch.no_grad():
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
pipe.encode_prompt(prompt=prompt)
)
with torch.autocast("cuda", torch.bfloat16):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
frames = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
guidance_scale=4.5,
num_inference_steps=64,
height=480,
width=848,
num_frames=163,
generator=torch.Generator("cuda").manual_seed(0),
output_type="latent",
return_dict=False,
)[0]
video_processor = VideoProcessor(vae_scale_factor=8)
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
latents_std = (
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
else:
frames = frames / pipe.vae.config.scaling_factor
with torch.no_grad():
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
video = video_processor.postprocess_video(video)[0]
export_to_video(video, "mochi.mp4", fps=30)
```
## Running inference with multiple GPUs
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
model_id = "genmo/mochi-1-preview"
transformer = MochiTransformer3DModel.from_pretrained(
model_id,
subfolder="transformer",
device_map="auto",
max_memory={0: "24GB", 1: "24GB"}
)
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
frames = pipe(
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
negative_prompt="",
height=480,
width=848,
num_frames=85,
num_inference_steps=50,
guidance_scale=4.5,
num_videos_per_prompt=1,
generator=torch.Generator(device="cuda").manual_seed(0),
max_sequence_length=256,
output_type="pil",
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
## Using single file loading with the Mochi Transformer
You can use `from_single_file` to load the Mochi transformer in its original format.
<Tip>
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
</Tip>
```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
model_id = "genmo/mochi-1-preview"
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
frames = pipe(
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
negative_prompt="",
height=480,
width=848,
num_frames=85,
num_inference_steps=50,
guidance_scale=4.5,
num_videos_per_prompt=1,
generator=torch.Generator(device="cuda").manual_seed(0),
max_sequence_length=256,
output_type="pil",
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
## MochiPipeline
[[autodoc]] MochiPipeline
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
## Image Prompting with IP-Adapters
An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
```python
import torch
from PIL import Image
from diffusers import StableDiffusion3Pipeline
from transformers import SiglipVisionModel, SiglipImageProcessor
image_encoder_id = "google/siglip-so400m-patch14-384"
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
feature_extractor = SiglipImageProcessor.from_pretrained(
image_encoder_id,
torch_dtype=torch.float16
)
image_encoder = SiglipVisionModel.from_pretrained(
image_encoder_id,
torch_dtype=torch.float16
).to( "cuda")
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-large",
torch_dtype=torch.float16,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
).to("cuda")
pipe.load_ip_adapter(ip_adapter_id)
pipe.set_ip_adapter_scale(0.6)
ref_img = Image.open("image.jpg").convert('RGB')
image = pipe(
width=1024,
height=1024,
prompt="a cat",
negative_prompt="lowres, low quality, worst quality",
num_inference_steps=24,
guidance_scale=5.0,
ip_adapter_image=ref_img
).images[0]
image.save("result.jpg")
```
<div class="justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/>
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption>
</div>
<Tip>
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
</Tip>
## Memory Optimisations for SD3
SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
### Running Inference with Model Offloading
+1 -2
View File
@@ -45,12 +45,11 @@ transformer = FluxTransformer2DModel.from_single_file(
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
generator=torch.manual_seed(0),
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt).images[0]
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
+3 -3
View File
@@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes.md)
- [TorchAO](./torchao.md)
- [GGUF](./gguf.md)
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
@@ -73,7 +73,7 @@ This will also allow us to push the trained LoRA parameters to the Hugging Face
Now, we can launch training using:
```bash
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
@@ -124,4 +124,4 @@ We provide several options for optimizing memory optimization:
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext
import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
if is_transformers_available():
from transformers import CLIPVisionModelWithProjection
vision = True
else:
vision = False
"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors"
--output_path "flux-ip-adapter-hf/"
"""
CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
args = parser.parse_args()
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
converted_state_dict = {}
# image_proj
## norm
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
## proj
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"ip_adapter.{i}."
# to_k_ip
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
)
# to_v_ip
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
)
return converted_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
num_layers = 19
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
print("Saving Flux IP-Adapter in Diffusers format.")
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
if vision:
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
model.save_pretrained(f"{args.output_path}/image_encoder")
if __name__ == "__main__":
main(args)
+3 -2
View File
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
@@ -265,9 +266,9 @@ if __name__ == "__main__":
"--image_size",
default=1024,
type=int,
choices=[512, 1024],
choices=[512, 1024, 2048],
required=False,
help="Image size of pretrained model, 512 or 1024.",
help="Image size of pretrained model, 512, 1024 or 2048.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
+2
View File
@@ -277,6 +277,7 @@ else:
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
@@ -765,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
+16 -3
View File
@@ -55,7 +55,8 @@ _import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
@@ -70,10 +71,15 @@ if is_torch_available():
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
"Mochi1LoraLoaderMixin",
"HunyuanVideoLoraLoaderMixin",
"SanaLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -81,15 +87,22 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
FluxLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Mochi1LoraLoaderMixin,
+529 -8
View File
@@ -33,15 +33,20 @@ from .unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available():
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
JointAttnProcessor2_0,
SD3IPAdapterJointAttnProcessor2_0,
)
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
)
logger = logging.get_logger(__name__)
@@ -348,3 +353,519 @@ class IPAdapterMixin:
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
Can be either:
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
for key in f.keys():
if any(key.startswith(prefix) for prefix in image_proj_keys):
diffusers_name = ".".join(key.split(".")[1:])
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
diffusers_name = (
".".join(key.split(".")[1:])
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
.replace("processor.", "")
)
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_pretrained_model_name_or_path is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
image_encoder = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_pretrained_model_name_or_path,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.to(self.device, dtype=image_encoder_dtype)
.eval()
)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
number of IP adapters and each must match the number of blocks.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
def LinearStrengthModel(start, finish, size):
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
transformer = self.transformer
if not isinstance(scale, list):
scale = [[scale] * transformer.config.num_layers]
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
if len(scale) != transformer.config.num_layers:
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
scale = [scale]
scale_configs = scale
key_id = 0
for attn_name, attn_processor in transformer.attn_processors.items():
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
attn_processor.scale[i] = scale_config[key_id]
key_id += 1
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove hidden encoder
self.transformer.encoder_hid_proj = None
self.transformer.config.encoder_hid_dim_type = None
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
class SD3IPAdapterMixin:
"""Mixin for handling StableDiffusion 3 IP Adapters."""
@property
def is_ip_adapter_active(self) -> bool:
"""Checks if IP-Adapter is loaded and scale > 0.
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
the image context is irrelevant.
Returns:
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
"""
scales = [
attn_proc.scale
for attn_proc in self.transformer.attn_processors.values()
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
]
return len(scales) > 0 and any(scale > 0 for scale in scales)
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str = "ip-adapter.safetensors",
subfolder: Optional[str] = None,
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
) -> None:
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
weight_name (`str`, defaults to "ip-adapter.safetensors"):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
subfolder (`str`, *optional*):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
`image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# Load the main state dict first
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
# Commons args for loading image encoder and image processor
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage,
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
self.register_modules(
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# Load IP-Adapter into transformer
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: float) -> None:
"""
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
the model to produce more diverse images, but they may not be as aligned with the image prompt.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.set_ip_adapter_scale(0.6)
>>> ...
```
Args:
scale (float):
IP-Adapter scale to be set.
"""
for attn_processor in self.transformer.attn_processors.values():
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
attn_processor.scale = scale
def unload_ip_adapter(self) -> None:
"""
Unloads the IP Adapter weights.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# Remove image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=None)
# Remove feature extractor
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=None)
# Remove image projection
self.transformer.image_proj = None
# Restore original attention processors layers
attn_procs = {
name: (
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
)
for name, value in self.transformer.attn_processors.items()
}
self.transformer.set_attn_processor(attn_procs)
@@ -643,7 +643,11 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
old_state_dict,
new_state_dict,
old_key,
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
[
f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
],
)
if "down" in old_key:
+402 -41
View File
@@ -1863,6 +1863,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging."
)
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)
if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer(
@@ -2309,16 +2312,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None
lora_A_weight_name = f"{name}.lora_A.weight"
lora_B_weight_name = f"{name}.lora_B.weight"
if lora_A_weight_name not in state_dict.keys():
lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
if lora_A_weight_name not in state_dict:
continue
in_features = state_dict[lora_A_weight_name].shape[1]
@@ -2329,57 +2333,106 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
continue
module_out_features, module_in_features = module_weight.shape
if out_features < module_out_features or in_features < module_in_features:
raise NotImplementedError(
f"Only LoRAs with input/output features higher than the current module's input/output features "
f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
f"this please open an issue at https://github.com/huggingface/diffusers/issues."
debug_message = ""
if in_features > module_in_features:
debug_message += (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)
debug_message = (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}"
)
if module_out_features != out_features:
if out_features > module_out_features:
debug_message += (
", and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
)
else:
debug_message += "."
logger.debug(debug_message)
if debug_message:
logger.debug(debug_message)
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
if out_features > module_out_features or in_features > module_in_features:
has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
# TODO: consider initializing this under meta device for optims.
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
expanded_module.weight.data.copy_(new_weight)
if module_bias is not None:
expanded_module.bias.data.copy_(module_bias)
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
)
# Only weights are expanded and biases are not. This is because only the input dimensions
# are changed while the output dimensions remain the same. The shape of the weight tensor
# is (out_features, in_features), while the shape of bias tensor is (out_features,), which
# explains the reason why only weights are expanded.
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
tmp_state_dict["bias"] = module_bias
expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
setattr(parent_module, current_module_name, expanded_module)
setattr(parent_module, current_module_name, expanded_module)
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
del tmp_state_dict
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(expanded_module.weight.data.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)
return has_param_with_shape_update
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
expanded_module_names = set()
transformer_state_dict = transformer.state_dict()
prefix = f"{cls.transformer_name}."
lora_module_names = [
key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
]
lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
lora_module_names = sorted(set(lora_module_names))
transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
unexpected_modules = set(lora_module_names) - set(transformer_module_names)
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for k in lora_module_names:
if k in unexpected_modules:
continue
base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
if base_weight_param.shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
if expanded_module_names:
logger.info(
f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
)
return lora_state_dict
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -3870,6 +3923,314 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components)
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`HunyuanVideoTransformer3DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
+1
View File
@@ -53,6 +53,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
}
@@ -32,6 +32,7 @@ from .single_file_utils import (
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
@@ -96,6 +97,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"default_subfolder": "vae",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
"MochiTransformer3DModel": {
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
+125 -9
View File
@@ -99,13 +99,15 @@ CHECKPOINT_KEY_NAMES = {
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"ltx-video": [
(
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
),
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.per_channel_statistics.mean-of-means",
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -151,12 +153,15 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
}
# Use to configure model sample size when original config is provided
@@ -587,11 +592,17 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
model_type = "flux-dev"
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint["img_in.weight"].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
else:
model_type = "flux-schnell"
elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
@@ -610,6 +621,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "autoencoder-dc-f128c512"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
model_type = "mochi-1-preview"
else:
model_type = "v1"
@@ -1750,6 +1764,12 @@ def swap_scale_shift(weight, dim):
return new_weight
def swap_proj_gate(weight):
proj, gate = weight.chunk(2, dim=0)
new_weight = torch.cat([gate, proj], dim=0)
return new_weight
def get_attn2_layers(state_dict):
attn2_layers = []
for key in state_dict.keys():
@@ -2247,9 +2267,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
TRANSFORMER_KEYS_RENAME_DICT = {
"model.diffusion_model.": "",
@@ -2406,3 +2424,101 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
new_state_dict = {}
# Comfy checkpoints add this prefix
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
# Convert time_embed
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
# Convert transformer blocks
num_layers = 48
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"blocks.{i}."
# norm1
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
if i < num_layers - 1:
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
else:
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
# Visual attention
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
# Context attention
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
old_prefix + "attn.q_norm_y.weight"
)
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
old_prefix + "attn.k_norm_y.weight"
)
if i < num_layers - 1:
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
old_prefix + "attn.proj_y.weight"
)
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
)
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
if i < num_layers - 1:
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
)
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
# Output layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
return new_state_dict
+179
View File
@@ -0,0 +1,179 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
if is_accelerate_available():
pass
logger = logging.get_logger(__name__)
class FluxTransformer2DLoadersMixin:
"""
Load layers into a [`FluxTransformer2DModel`].
"""
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 0
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
num_image_text_embed = 4
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
num_image_text_embed = 16
# IP-Adapter
num_image_text_embeds += [num_image_text_embed]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
dtype=self.dtype,
device=self.device,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device = self.device
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
key_id += 1
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
class SD3Transformer2DLoadersMixin:
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
Args:
state_dict (`Dict`):
State dict with keys "ip_adapter", which contains parameters for attention processors, and
"image_proj", which contains parameters for image projection net.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# IP-Adapter cross attention parameters
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
# Dict where key is transformer layer index, value is attention processor's state dict
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
for key, weights in state_dict["ip_adapter"].items():
idx, name = key.split(".", maxsplit=1)
layer_state_dict[int(idx)][name] = weights
# Create IP-Adapter attention processor
attn_procs = {}
for idx, name in enumerate(self.attn_processors.keys()):
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size,
ip_hidden_states_dim=ip_hidden_states_dim,
head_dim=self.config.attention_head_dim,
timesteps_emb_dim=timesteps_emb_dim,
).to(self.device, dtype=self.dtype)
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
load_model_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)
self.set_attn_processor(attn_procs)
# Image projetion parameters
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
num_queries = state_dict["image_proj"]["latents"].shape[1]
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
# Image projection
self.image_proj = IPAdapterTimeImageProjection(
embed_dim=embed_dim,
output_dim=output_dim,
hidden_dim=hidden_dim,
heads=heads,
num_queries=num_queries,
timestep_in_dim=timestep_in_dim,
).to(device=self.device, dtype=self.dtype)
if not low_cpu_mem_usage:
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
else:
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+10 -3
View File
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
self._chunk_dim = dim
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
+317 -1
View File
@@ -575,7 +575,7 @@ class Attention(nn.Module):
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
@@ -2653,6 +2653,149 @@ class FusedFluxAttnProcessor2_0_NPU:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_attn_output = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_attn_output = scale * ip_attn_output
ip_attn_output = ip_attn_output.to(ip_query.dtype)
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -5243,6 +5386,177 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
return hidden_states
class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""
Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
additional image-based information and timestep embeddings.
Args:
hidden_size (`int`):
The number of hidden channels.
ip_hidden_states_dim (`int`):
The image feature dimension.
head_dim (`int`):
The number of head channels.
timesteps_emb_dim (`int`, defaults to 1280):
The number of input channels for timestep embedding.
scale (`float`, defaults to 0.5):
IP-Adapter scale.
"""
def __init__(
self,
hidden_size: int,
ip_hidden_states_dim: int,
head_dim: int,
timesteps_emb_dim: int = 1280,
scale: float = 0.5,
):
super().__init__()
# To prevent circular import
from .normalization import AdaLayerNorm, RMSNorm
self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.norm_q = RMSNorm(head_dim, 1e-6)
self.norm_k = RMSNorm(head_dim, 1e-6)
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
self.scale = scale
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
ip_hidden_states: torch.FloatTensor = None,
temb: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Perform the attention computation, integrating image features (if provided) and timestep embeddings.
If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
Args:
attn (`Attention`):
Attention instance.
hidden_states (`torch.FloatTensor`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
The encoder hidden states.
attention_mask (`torch.FloatTensor`, *optional*):
Attention mask.
ip_hidden_states (`torch.FloatTensor`, *optional*):
Image embeddings.
temb (`torch.FloatTensor`, *optional*):
Timestep embeddings.
Returns:
`torch.FloatTensor`: Output hidden states.
"""
residual = hidden_states
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
img_query = query
img_key = key
img_value = value
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP Adapter
if self.scale != 0 and ip_hidden_states is not None:
# Norm image features
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
# To k and v
ip_key = self.to_k_ip(norm_ip_hidden_states)
ip_value = self.to_v_ip(norm_ip_hidden_states)
# Reshape
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Norm
query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
ip_key = self.norm_ip_k(ip_key)
# cat img
key = torch.cat([img_key, ip_key], dim=2)
value = torch.cat([img_value, ip_value], dim=2)
ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + ip_hidden_states * self.scale
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -5725,6 +6039,7 @@ CROSS_ATTENTION_PROCESSORS = (
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
)
AttentionProcessor = Union[
@@ -5772,6 +6087,7 @@ AttentionProcessor = Union[
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
LoRAAttnProcessor,
@@ -168,6 +168,7 @@ class HunyuanVideoResnetBlockCausal3D(nn.Module):
self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.contiguous()
residual = hidden_states
hidden_states = self.norm1(hidden_states)
@@ -792,12 +793,12 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
self.tile_sample_min_num_frames = 64
self.tile_sample_min_num_frames = 16
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_sample_stride_num_frames = 48
self.tile_sample_stride_num_frames = 12
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
@@ -1003,7 +1004,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
@@ -1020,7 +1021,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
+183 -2
View File
@@ -691,7 +691,7 @@ class CogVideoXPatchEmbed(nn.Module):
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = torch.zeros(
joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -1535,7 +1535,7 @@ class ImageProjection(nn.Module):
batch_size = image_embeds.shape[0]
# image
image_embeds = self.image_embeds(image_embeds)
image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
@@ -2396,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
return out
class IPAdapterTimeImageProjectionBlock(nn.Module):
"""Block for IPAdapterTimeImageProjection.
Args:
hidden_dim (`int`, defaults to 1280):
The number of hidden channels.
dim_head (`int`, defaults to 64):
The number of head channels.
heads (`int`, defaults to 20):
Parallel attention heads.
ffn_ratio (`int`, defaults to 4):
The expansion ratio of feedforward network hidden layer channels.
"""
def __init__(
self,
hidden_dim: int = 1280,
dim_head: int = 64,
heads: int = 20,
ffn_ratio: int = 4,
) -> None:
super().__init__()
from .attention import FeedForward
self.ln0 = nn.LayerNorm(hidden_dim)
self.ln1 = nn.LayerNorm(hidden_dim)
self.attn = Attention(
query_dim=hidden_dim,
cross_attention_dim=hidden_dim,
dim_head=dim_head,
heads=heads,
bias=False,
out_bias=False,
)
self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
# AdaLayerNorm
self.adaln_silu = nn.SiLU()
self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
self.adaln_norm = nn.LayerNorm(hidden_dim)
# Set attention scale and fuse KV
self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
self.attn.fuse_projections()
self.attn.to_k = None
self.attn.to_v = None
def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x (`torch.Tensor`):
Image features.
latents (`torch.Tensor`):
Latent features.
timestep_emb (`torch.Tensor`):
Timestep embedding.
Returns:
`torch.Tensor`: Output latent features.
"""
# Shift and scale for AdaLayerNorm
emb = self.adaln_proj(self.adaln_silu(timestep_emb))
shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
# Fused Attention
residual = latents
x = self.ln0(x)
latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
batch_size = latents.shape[0]
query = self.attn.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.attn.heads
query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
latents = weight @ value
latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
latents = self.attn.to_out[0](latents)
latents = self.attn.to_out[1](latents)
latents = latents + residual
## FeedForward
residual = latents
latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
return self.ff(latents) + residual
# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
class IPAdapterTimeImageProjection(nn.Module):
"""Resampler of SD3 IP-Adapter with timestep embedding.
Args:
embed_dim (`int`, defaults to 1152):
The feature dimension.
output_dim (`int`, defaults to 2432):
The number of output channels.
hidden_dim (`int`, defaults to 1280):
The number of hidden channels.
depth (`int`, defaults to 4):
The number of blocks.
dim_head (`int`, defaults to 64):
The number of head channels.
heads (`int`, defaults to 20):
Parallel attention heads.
num_queries (`int`, defaults to 64):
The number of queries.
ffn_ratio (`int`, defaults to 4):
The expansion ratio of feedforward network hidden layer channels.
timestep_in_dim (`int`, defaults to 320):
The number of input channels for timestep embedding.
timestep_flip_sin_to_cos (`bool`, defaults to True):
Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
timestep_freq_shift (`int`, defaults to 0):
Controls the timestep delta between frequencies between dimensions.
"""
def __init__(
self,
embed_dim: int = 1152,
output_dim: int = 2432,
hidden_dim: int = 1280,
depth: int = 4,
dim_head: int = 64,
heads: int = 20,
num_queries: int = 64,
ffn_ratio: int = 4,
timestep_in_dim: int = 320,
timestep_flip_sin_to_cos: bool = True,
timestep_freq_shift: int = 0,
) -> None:
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
self.proj_in = nn.Linear(embed_dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList(
[IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
)
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass.
Args:
x (`torch.Tensor`):
Image features.
timestep (`torch.Tensor`):
Timestep in denoising process.
Returns:
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
"""
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
timestep_emb = self.time_embedding(timestep_emb)
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
x = x + timestep_emb[:, None]
for block in self.layers:
latents = block(x, latents, timestep_emb)
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents, timestep_emb
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
+32 -14
View File
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try:
return next(parameter.parameters()).dtype
except StopIteration:
try:
return next(parameter.buffers()).dtype
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
last_dtype = None
for param in parameter.parameters():
last_dtype = param.dtype
if param.is_floating_point():
return param.dtype
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
for buffer in parameter.buffers():
last_dtype = buffer.dtype
if buffer.is_floating_point():
return buffer.dtype
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
class ModelMixin(torch.nn.Module, PushToHubMixin):
@@ -802,7 +820,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None:
if hf_quantizer is not None and is_bnb_quantization_method:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
@@ -21,7 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
@@ -177,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attn_output, context_attn_output = self.attn(
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
@@ -195,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
@@ -212,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
):
"""
The Transformer model introduced in Flux.
@@ -482,6 +491,11 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -19,7 +19,8 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..embeddings import (
@@ -32,6 +33,9 @@ from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideoAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -496,7 +500,47 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
Args:
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `20`):
The number of layers of dual-stream blocks to use.
num_single_layers (`int`, defaults to `40`):
The number of layers of single-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`int`, defaults to `2`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
qk_norm (`str`, defaults to `rms_norm`):
The normalization to use for the query and key projections in the attention layers.
guidance_embeds (`bool`, defaults to `True`):
Whether to use guidance embeddings in the model.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
pooled_projection_dim (`int`, defaults to `768`):
The dimension of the pooled projection of the text embeddings.
rope_theta (`float`, defaults to `256.0`):
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -630,8 +674,24 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -717,6 +777,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
@@ -20,6 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
@@ -304,7 +305,7 @@ class MochiRoPE(nn.Module):
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
@@ -334,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["MochiTransformerBlock"]
@register_to_config
def __init__(
@@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
@@ -103,7 +103,9 @@ class SD3SingleTransformerBlock(nn.Module):
return hidden_states
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class SD3Transformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
"""
The Transformer model introduced in Stable Diffusion 3.
@@ -349,8 +351,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
Embeddings projected from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states (`list` of `torch.Tensor`):
@@ -390,6 +392,12 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
is_skip = True if skip_layers is not None and index_block in skip_layers else False
@@ -411,11 +419,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
hidden_states,
encoder_hidden_states,
temb,
joint_attention_kwargs,
**ckpt_kwargs,
)
elif not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
+8 -1
View File
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
conditioning with `class_embed_type` equal to `None`.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
time_embedding_dim: Optional[int] = None,
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
+75 -8
View File
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(hidden_states)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
+2
View File
@@ -128,6 +128,7 @@ else:
]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -539,6 +540,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
+32 -5
View File
@@ -35,9 +35,12 @@ from .controlnet import (
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
]
)
@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
]
)
@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
if "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
else:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin):
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "Img2Img" in orig_class_name:
to_replace = "Img2ImgPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin):
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "Inpaint" in orig_class_name:
to_replace = "InpaintPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin):
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
return hidden_states
+2
View File
@@ -26,6 +26,7 @@ else:
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
_import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_flux import FluxPipeline
from .pipeline_flux_control import FluxControlPipeline
from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+173 -5
View File
@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,6 +149,7 @@ class FluxPipeline(
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
FluxIPAdapterMixin,
):
r"""
The Flux pipeline for text-to-image generation.
@@ -169,8 +177,8 @@ class FluxPipeline(
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -182,6 +190,8 @@ class FluxPipeline(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -193,6 +203,8 @@ class FluxPipeline(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -377,14 +389,60 @@ class FluxPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
image_embeds.append(single_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -419,10 +477,33 @@ class FluxPipeline(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -551,6 +632,9 @@ class FluxPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
@@ -561,6 +645,12 @@ class FluxPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -610,6 +700,17 @@ class FluxPipeline(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image:
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -647,8 +748,12 @@ class FluxPipeline(
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -670,6 +775,7 @@ class FluxPipeline(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -684,6 +790,21 @@ class FluxPipeline(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -725,12 +846,43 @@ class FluxPipeline(
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -746,6 +898,22 @@ class FluxPipeline(
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -403,7 +403,6 @@ class FluxControlPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
def check_inputs(
self,
prompt,
File diff suppressed because it is too large Load Diff
@@ -1095,7 +1095,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# predict the noise residual
if self.controlnet.config.guidance_embeds:
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds
if use_guidance:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
@@ -20,6 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
@@ -132,7 +133,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class HunyuanVideoPipeline(DiffusionPipeline):
class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
@@ -142,7 +143,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
Args:
text_encoder ([`LlamaModel`]):
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
tokenizer_2 (`LlamaTokenizer`):
tokenizer (`LlamaTokenizer`):
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
transformer ([`HunyuanVideoTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
@@ -447,6 +448,10 @@ class HunyuanVideoPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@@ -471,6 +476,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -525,6 +531,10 @@ class HunyuanVideoPipeline(DiffusionPipeline):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -562,6 +572,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
device = self._execution_device
@@ -640,6 +651,7 @@ class HunyuanVideoPipeline(DiffusionPipeline):
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
@@ -188,6 +188,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3DModel,
force_zeros_for_empty_prompt: bool = False,
):
super().__init__()
@@ -205,10 +206,11 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
)
self.default_height = 480
self.default_width = 848
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
def _get_t5_prompt_embeds(
self,
@@ -236,7 +238,11 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
if prompt == "" or prompt[-1] == "":
# The original Mochi implementation zeros out empty negative prompts
# but this can lead to overflow when placing the entire pipeline under the autocast context
# adding this here so that we can enable zeroing prompts if necessary
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
text_input_ids = torch.zeros_like(text_input_ids, device=device)
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
+22
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import fnmatch
import importlib
import inspect
@@ -811,6 +812,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
expected_types = pipeline_class._get_signature_types()
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -833,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
for key in init_dict.keys():
if key not in passed_class_obj:
continue
if "scheduler" in key:
continue
class_obj = passed_class_obj[key]
_expected_class_types = []
for expected_type in expected_types[key]:
if isinstance(expected_type, enum.EnumMeta):
_expected_class_types.extend(expected_type.__members__.keys())
else:
_expected_class_types.append(expected_type.__name__)
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if not _is_valid_type:
logger.warning(
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
)
# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
raise NotImplementedError(
@@ -446,7 +446,7 @@ class StableAudioPipeline(DiffusionPipeline):
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
# check num_channels
@@ -255,7 +255,12 @@ class StableDiffusionPipeline(
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
is_unet_sample_size_less_64 = (
hasattr(unet.config, "sample_size")
and self._is_unet_config_sample_size_int
and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -902,8 +907,18 @@ class StableDiffusionPipeline(
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if not height or not width:
height = (
self.unet.config.sample_size
if self._is_unet_config_sample_size_int
else self.unet.config.sample_size[0]
)
width = (
self.unet.config.sample_size
if self._is_unet_config_sample_size_int
else self.unet.config.sample_size[1]
)
height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
@@ -1,4 +1,4 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,14 +17,16 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,7 +144,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
_optional_components = []
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -191,6 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()
@@ -204,6 +212,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -683,6 +693,83 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
def interrupt(self):
return self._interrupt
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
"""Encodes the given image into a feature representation using a pre-trained image encoder.
Args:
image (`PipelineImageInput`):
Input image to be encoded.
device: (`torch.device`):
Torch device.
Returns:
`torch.Tensor`: The encoded image feature representation.
"""
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=self.dtype)
return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
"""Prepares image embeddings for use in the IP-Adapter.
Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
Args:
ip_adapter_image (`PipelineImageInput`, *optional*):
The input image to extract features from for IP-Adapter.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Precomputed image embeddings.
device: (`torch.device`, *optional*):
Torch device.
num_images_per_prompt (`int`, defaults to 1):
Number of images that should be generated per prompt.
do_classifier_free_guidance (`bool`, defaults to True):
Whether to use classifier free guidance or not.
"""
device = device or self._execution_device
if ip_adapter_image_embeds is not None:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
else:
single_image_embeds = ip_adapter_image_embeds
elif ip_adapter_image is not None:
single_image_embeds = self.encode_image(ip_adapter_image, device)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.zeros_like(single_image_embeds)
else:
raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
return image_embeds.to(device=device)
def enable_sequential_cpu_offload(self, *args, **kwargs):
if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
logger.warning(
"`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
)
super().enable_sequential_cpu_offload(*args, **kwargs)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -705,6 +792,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -713,9 +802,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: List[int] = None,
skip_layer_guidance_scale: int = 2.8,
skip_layer_guidance_stop: int = 0.2,
skip_layer_guidance_start: int = 0.01,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_start: float = 0.01,
mu: Optional[float] = None,
):
r"""
@@ -781,6 +870,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -938,7 +1032,22 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
# 6. Prepare image embeddings
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
else:
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -289,6 +289,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -291,14 +291,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
@@ -318,6 +318,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -381,6 +381,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
@@ -392,6 +392,21 @@ class FluxControlImg2ImgPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxControlInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+29 -38
View File
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"]
if not local_files_only:
# `model_info` call must guarded with the above condition.
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
try:
# Load from URL
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e:
# `model_info` call must guarded with the above condition.
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
elif local_files_only:
_check_if_shards_exist_locally(
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
try:
# Load from URL
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
if subfolder is not None:
cached_folder = os.path.join(cache_dir, subfolder)
cached_folder = os.path.join(cached_folder, subfolder)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
return cached_folder, sharded_metadata
-4
View File
@@ -29,7 +29,6 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
is_torch_version,
require_peft_backend,
skip_mps,
@@ -37,9 +36,6 @@ from diffusers.utils.testing_utils import (
)
if is_peft_available():
pass
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+144 -42
View File
@@ -340,21 +340,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because lora input features is less than original. We only
# support expanding the module, not shrinking it
with self.assertRaises(NotImplementedError):
pipe.load_lora_weights(lora_state_dict, "adapter-1")
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -430,10 +415,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_lora_expanding_shape_with_normal_lora_raises_error(self):
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
def test_lora_expanding_shape_with_normal_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
@@ -478,27 +463,18 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
# input features before expansion. This should raise an error about the weight shapes being incompatible.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
# We should have `adapter-1` as the only adapter.
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-2")
# Check if the output is the same after lora loading error
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
# original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
# weight is compatible with the current model inadequate. This should be addressed when attempting support for
# https://github.com/huggingface/diffusers/issues/10180 (TODO)
# This should raise a runtime error on input shapes being incompatible.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
@@ -521,14 +497,11 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
@@ -546,6 +519,98 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"adapter-2",
)
def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
# tested with it.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
components["transformer"] = transformer
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
def test_load_regular_lora(self):
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
# transformers include Flux Fill, Flux Control, etc.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -760,3 +825,40 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora_with_turbo(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
)
else:
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)
image = self.pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=self.num_inference_steps,
guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = image[0, -3:, -3:, -1].flatten()
if "Canny" in lora_ckpt_id:
expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
else:
expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
+228
View File
@@ -0,0 +1,228 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_torch_version,
require_peft_backend,
skip_mps,
torch_device,
)
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend
@skip_mps
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
}
transformer_cls = HunyuanVideoTransformer3DModel
vae_kwargs = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 4,
"down_block_types": (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
"up_block_types": (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
"block_out_channels": (8, 8, 8, 8),
"layers_per_block": 1,
"act_fn": "silu",
"norm_num_groups": 4,
"scaling_factor": 0.476986,
"spatial_compression_ratio": 8,
"temporal_compression_ratio": 4,
"mid_block_add_attention": True,
}
vae_cls = AutoencoderKLHunyuanVideo
has_two_text_encoders = True
tokenizer_cls, tokenizer_id, tokenizer_subfolder = (
LlamaTokenizerFast,
"hf-internal-testing/tiny-random-hunyuanvideo",
"tokenizer",
)
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = (
CLIPTokenizer,
"hf-internal-testing/tiny-random-hunyuanvideo",
"tokenizer_2",
)
text_encoder_cls, text_encoder_id, text_encoder_subfolder = (
LlamaModel,
"hf-internal-testing/tiny-random-hunyuanvideo",
"text_encoder",
)
text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = (
CLIPTextModel,
"hf-internal-testing/tiny-random-hunyuanvideo",
"text_encoder_2",
)
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 16
num_channels = 4
num_frames = 9
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
sizes = (4, 4)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "",
"num_frames": num_frames,
"num_inference_steps": 1,
"guidance_scale": 6.0,
"height": 32,
"width": 32,
"max_sequence_length": sequence_length,
"prompt_template": {"template": "{}", "crop_start": 0},
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=True,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(
prompt=inputs["prompt"],
height=inputs["height"],
width=inputs["width"],
num_frames=inputs["num_frames"],
num_inference_steps=inputs["num_inference_steps"],
max_sequence_length=inputs["max_sequence_length"],
output_type="np",
)[0]
self.assertTrue(np.isnan(out).all())
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
# TODO(aryan): Fix the following test
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
def test_simple_inference_save_pretrained(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
-4
View File
@@ -23,7 +23,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
is_torch_version,
require_peft_backend,
skip_mps,
@@ -31,9 +30,6 @@ from diffusers.utils.testing_utils import (
)
if is_peft_available():
pass
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
-4
View File
@@ -29,7 +29,6 @@ from diffusers import (
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
is_peft_available,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
@@ -37,9 +36,6 @@ from diffusers.utils.testing_utils import (
)
if is_peft_available():
pass
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
+22 -12
View File
@@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests:
has_two_text_encoders = False
has_three_text_encoders = False
text_encoder_cls, text_encoder_id = None, None
text_encoder_2_cls, text_encoder_2_id = None, None
text_encoder_3_cls, text_encoder_3_id = None, None
tokenizer_cls, tokenizer_id = None, None
tokenizer_2_cls, tokenizer_2_id = None, None
tokenizer_3_cls, tokenizer_3_id = None, None
text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
unet_kwargs = None
transformer_cls = None
@@ -124,16 +124,26 @@ class PeftLoraLoaderMixinTests:
torch.manual_seed(0)
vae = self.vae_cls(**self.vae_kwargs)
text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
text_encoder = self.text_encoder_cls.from_pretrained(
self.text_encoder_id, subfolder=self.text_encoder_subfolder
)
tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
if self.text_encoder_2_cls is not None:
text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)
text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
)
tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
)
if self.text_encoder_3_cls is not None:
text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)
text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
)
tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
)
text_lora_config = LoraConfig(
r=rank,
@@ -43,10 +43,14 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
"down_block_types": (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
"up_block_types": (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
"block_out_channels": (8, 8, 8, 8),
"layers_per_block": 1,
@@ -154,6 +158,27 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# We need to overwrite this test because the base test does not account length of down_block_types
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 16, 16, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@@ -146,7 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder"}
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_from_pretrained_hub(self):
@@ -65,7 +65,7 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder"}
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test unsupported.")
+3 -1
View File
@@ -803,7 +803,7 @@ class ModelTesterMixin:
self.assertFalse(model.is_gradient_checkpointing)
@require_torch_accelerator_with_training
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
@@ -850,6 +850,8 @@ class ModelTesterMixin:
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
@@ -18,6 +18,8 @@ import unittest
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -26,6 +28,56 @@ from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
+42
View File
@@ -105,6 +105,23 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
"AttnDownBlock2D",
"UNetMidBlock2D",
"UpBlock2D",
"DownBlock2D",
}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 8
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -220,6 +237,17 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 32
block_out_channels = (32, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -329,3 +357,17 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_forward_with_norm_groups(self):
# not required for this model
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
block_out_channels = (32, 64, 64, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, block_out_channels=block_out_channels
)
def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
+2
View File
@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+113 -1
View File
@@ -16,13 +16,14 @@ from diffusers.utils.testing_utils import (
)
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -91,6 +92,8 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -296,3 +299,112 @@ class FluxPipelineSlowTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
weight_name = "ip_adapter.safetensors"
ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
prompt_embeds = torch.load(
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
)
pooled_prompt_embeds = torch.load(
hf_hub_download(
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
)
)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
return {
"prompt_embeds": prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
"ip_adapter_image": ip_adapter_image,
"num_inference_steps": 2,
"guidance_scale": 3.5,
"true_cfg_scale": 4.0,
"max_sequence_length": 256,
"output_type": "np",
"generator": generator,
}
def test_flux_ip_adapter_inference(self):
pipe = self.pipeline_class.from_pretrained(
self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
)
pipe.load_ip_adapter(
self.ip_adapter_repo_id,
weight_name=self.weight_name,
image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path,
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.1855,
0.1680,
0.1406,
0.1953,
0.1699,
0.1465,
0.2012,
0.1738,
0.1484,
0.2051,
0.1797,
0.1523,
0.2012,
0.1719,
0.1445,
0.2070,
0.1777,
0.1465,
0.2090,
0.1836,
0.1484,
0.2129,
0.1875,
0.1523,
0.2090,
0.1816,
0.1484,
0.2110,
0.1836,
0.1543,
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
@@ -0,0 +1,215 @@
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxControlInpaintPipeline,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxControlInpaintPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=8,
out_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = Image.new("RGB", (8, 8), 0)
control_image = Image.new("RGB", (8, 8), 0)
mask_image = Image.new("RGB", (8, 8), 255)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"control_image": control_image,
"generator": generator,
"image": image,
"mask_image": mask_image,
"strength": 0.8,
"num_inference_steps": 2,
"guidance_scale": 30.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
# def test_flux_different_prompts(self):
# pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
# inputs = self.get_dummy_inputs(torch_device)
# output_same_prompt = pipe(**inputs).images[0]
# inputs = self.get_dummy_inputs(torch_device)
# inputs["prompt_2"] = "a different prompt"
# output_different_prompts = pipe(**inputs).images[0]
# max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# # Outputs should be different here
# # For some reasons, they don't show large differences
# assert max_diff > 1e-6
def test_flux_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
+1 -1
View File
@@ -275,7 +275,7 @@ class MochiPipelineIntegrationTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
def test_cogvideox(self):
def test_mochi(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
@@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests(
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
def test_pipeline_accept_tuple_type_unet_sample_size(self):
# the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size
sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
sample_size = [60, 80]
customised_unet = UNet2DConditionModel(sample_size=sample_size)
pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)
assert pipe.unet.config.sample_size == sample_size
@slow
@require_torch_gpu
@@ -103,6 +103,8 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
+10
View File
@@ -1802,6 +1802,16 @@ class PipelineFastTests(unittest.TestCase):
sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5
def test_wrong_model(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
_ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
)
assert "is of type" in str(error_context.exception)
assert "but should be" in str(error_context.exception)
@slow
@require_torch_gpu
+90 -1
View File
@@ -29,7 +29,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -54,6 +54,7 @@ from ..models.autoencoders.vae import (
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
@@ -483,6 +484,94 @@ class IPAdapterTesterMixin:
)
class FluxIPAdapterTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
It provides a set of common tests for pipelines that support IP Adapters.
"""
def test_pipeline_signature(self):
parameters = inspect.signature(self.pipeline_class.__call__).parameters
assert issubclass(self.pipeline_class, FluxIPAdapterMixin)
self.assertIn(
"ip_adapter_image",
parameters,
"`ip_adapter_image` argument must be supported by the `__call__` method",
)
self.assertIn(
"ip_adapter_image_embeds",
parameters,
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
)
def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
return torch.randn((1, 1, image_embed_dim), device=torch_device)
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
inputs["negative_prompt"] = ""
inputs["true_cfg_scale"] = 4.0
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_embed_dim = pipe.transformer.config.pooled_projection_dim
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs)[0]
else:
output_without_adapter = expected_pipe_slice
adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
)
class PipelineLatentTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+43 -19
View File
@@ -278,13 +278,14 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
def test_offload(self):
def test_device_map(self):
"""
Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
that the device map is correctly set (in the `hf_device_map` attribute of the model).
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
"""
device_map_offload = {
custom_device_map_dict = {
"time_text_embed": torch_device,
"context_embedder": torch_device,
"x_embedder": torch_device,
@@ -293,27 +294,50 @@ class TorchAoTest(unittest.TestCase):
"norm_out": torch_device,
"proj_out": "cpu",
}
device_maps = ["auto", custom_device_map_dict]
inputs = self.get_dummy_tensor_inputs(torch_device)
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map_offload,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
for device_map in device_maps:
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
self.assertTrue(quantized_model.hf_device_map == device_map_offload)
# Test non-sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])