Compare commits
108 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 560fb5f4d6 | |||
| 8ab26ac9bf | |||
| 9f305e7ce2 | |||
| 2c25bf5bef | |||
| 0e14cacffc | |||
| 13ea83f0fa | |||
| 2b432ac5a8 | |||
| 263b973466 | |||
| a663a67ea2 | |||
| 526858c801 | |||
| b3d2dd36d7 | |||
| abfa922410 | |||
| 6a7b01f60f | |||
| e8aacda762 | |||
| 12184f4015 | |||
| 6e1d2da194 | |||
| 11b1151840 | |||
| cd4d0d8ffb | |||
| 4b557132ce | |||
| 851dfa30ae | |||
| ea1ba0ba53 | |||
| 9d27df8071 | |||
| 055d95543a | |||
| 71cc2013fe | |||
| c34fc34563 | |||
| 5fcee4a447 | |||
| 76e2727b5c | |||
| 02c777c065 | |||
| 6a970a45c5 | |||
| ffc0eaab6d | |||
| 3c2e2aa8a9 | |||
| b58868e6f4 | |||
| da21d590b5 | |||
| 7c2f0afb1c | |||
| f615f00f58 | |||
| 6aaa0518e3 | |||
| 233dffdc3f | |||
| be2070991f | |||
| bf9a641f1a | |||
| a756694bf0 | |||
| d41388145e | |||
| a6288a5571 | |||
| 7d4db57037 | |||
| 902008608a | |||
| c8ee4af228 | |||
| b64ca6c11c | |||
| e12d610faa | |||
| bf6eaa8aec | |||
| 17128c42a4 | |||
| dbc1d505f0 | |||
| 151b74cd77 | |||
| 41ba8c0bf6 | |||
| 3191248472 | |||
| 648d968cfc | |||
| b756ec6e80 | |||
| d8825e7697 | |||
| 074798b299 | |||
| 3ee966950b | |||
| 9764f229d4 | |||
| 1826a1e7d3 | |||
| 0ed09a17bb | |||
| 2f7a417d1f | |||
| 4450d26b63 | |||
| f781b8c30c | |||
| 9c0e20de61 | |||
| f35a38725b | |||
| f66bd3261c | |||
| c4c99c3907 | |||
| 862a7d5038 | |||
| 8304adce2a | |||
| b389f339ec | |||
| e222246b4e | |||
| 83709d5a06 | |||
| 8eb73c872a | |||
| 88b015dc9f | |||
| 63cdf9c0ba | |||
| 0ac52d6f09 | |||
| ba6fd6eb30 | |||
| 9408aa2dfc | |||
| ec1c7a793f | |||
| 9c68c945e9 | |||
| 2739241ad1 | |||
| 1524781b88 | |||
| 128b96f369 | |||
| e24941b2a7 | |||
| f9d5a9324d | |||
| ac86393487 | |||
| 0d96a894a7 | |||
| 6fb94d51cb | |||
| 7667cfcb41 | |||
| 9f00c617a0 | |||
| aafed3f8dd | |||
| 5ed761a6f2 | |||
| 2f023d7b84 | |||
| e9a3911b67 | |||
| 7186bb45f0 | |||
| 438bd60549 | |||
| 87e8157437 | |||
| 3f421fe09f | |||
| a7d50524dd | |||
| 672bd49573 | |||
| ea893a9ae7 | |||
| 5fb3a98517 | |||
| aace1f412b | |||
| 8957324363 | |||
| e68092a471 | |||
| 3bf5400a64 | |||
| 02cbe972c3 |
@@ -357,6 +357,10 @@ jobs:
|
||||
config:
|
||||
- backend: "bitsandbytes"
|
||||
test_location: "bnb"
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: PyTorch CUDA checkpoint tests on Ubuntu
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -137,7 +137,7 @@ jobs:
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -165,7 +165,8 @@ jobs:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache defaults:
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
${CONDA_RUN} python -m uv pip install transformers --upgrade
|
||||
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Test installing diffusers and importing
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
pip install -i https://test.pypi.org/simple/ diffusers
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
|
||||
@@ -157,6 +157,10 @@
|
||||
title: Getting Started
|
||||
- local: quantization/bitsandbytes
|
||||
title: bitsandbytes
|
||||
- local: quantization/gguf
|
||||
title: gguf
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
title: Quantization Methods
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
@@ -234,6 +238,8 @@
|
||||
title: Textual Inversion
|
||||
- local: api/loaders/unet
|
||||
title: UNet
|
||||
- local: api/loaders/transformer_sd3
|
||||
title: SD3Transformer2D
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
title: Loaders
|
||||
@@ -270,6 +276,8 @@
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/hunyuan_video_transformer_3d
|
||||
title: HunyuanVideoTransformer3DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
@@ -316,6 +324,8 @@
|
||||
title: AutoencoderKLAllegro
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoder_kl_hunyuan_video
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
title: AutoencoderKLLTXVideo
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
@@ -392,8 +402,12 @@
|
||||
title: DiT
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/pix2pix
|
||||
@@ -415,7 +429,7 @@
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTX
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
|
||||
@@ -15,40 +15,135 @@ specific language governing permissions and limitations under the License.
|
||||
An attention processor is a class for applying different types of attention mechanisms.
|
||||
|
||||
## AttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.AttnProcessor
|
||||
|
||||
## AttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.AttnProcessor2_0
|
||||
|
||||
## AttnAddedKVProcessor
|
||||
[[autodoc]] models.attention_processor.AttnAddedKVProcessor
|
||||
|
||||
## AttnAddedKVProcessor2_0
|
||||
[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0
|
||||
|
||||
## CrossFrameAttnProcessor
|
||||
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
|
||||
[[autodoc]] models.attention_processor.AttnProcessorNPU
|
||||
|
||||
## CustomDiffusionAttnProcessor
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
|
||||
|
||||
## CustomDiffusionAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
|
||||
|
||||
## CustomDiffusionXFormersAttnProcessor
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
|
||||
|
||||
## FusedAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
|
||||
|
||||
## Allegro
|
||||
|
||||
[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0
|
||||
|
||||
## AuraFlow
|
||||
|
||||
[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0
|
||||
|
||||
## CogVideoX
|
||||
|
||||
[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0
|
||||
|
||||
## CrossFrameAttnProcessor
|
||||
|
||||
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
|
||||
|
||||
## Custom Diffusion
|
||||
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
|
||||
|
||||
## Flux
|
||||
|
||||
[[autodoc]] models.attention_processor.FluxAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedFluxAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0
|
||||
|
||||
## Hunyuan
|
||||
|
||||
[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0
|
||||
|
||||
## IdentitySelfAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
|
||||
|
||||
## JointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0
|
||||
|
||||
## LoRA
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
|
||||
|
||||
## Lumina-T2X
|
||||
|
||||
[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0
|
||||
|
||||
## Mochi
|
||||
|
||||
[[autodoc]] models.attention_processor.MochiAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0
|
||||
|
||||
## Sana
|
||||
|
||||
[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0
|
||||
|
||||
## Stable Audio
|
||||
|
||||
[[autodoc]] models.attention_processor.StableAudioAttnProcessor2_0
|
||||
|
||||
## SlicedAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.SlicedAttnProcessor
|
||||
|
||||
## SlicedAttnAddedKVProcessor
|
||||
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
|
||||
|
||||
## XFormersAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.XFormersAttnProcessor
|
||||
|
||||
## AttnProcessorNPU
|
||||
[[autodoc]] models.attention_processor.AttnProcessorNPU
|
||||
[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor
|
||||
|
||||
## XLAFlashAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0
|
||||
|
||||
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
|
||||
## SD3IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
|
||||
- all
|
||||
- is_ip_adapter_active
|
||||
|
||||
## IPAdapterMaskProcessor
|
||||
|
||||
[[autodoc]] image_processor.IPAdapterMaskProcessor
|
||||
@@ -17,6 +17,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
|
||||
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
|
||||
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
|
||||
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
|
||||
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
|
||||
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
@@ -38,6 +41,18 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
|
||||
|
||||
## FluxLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
## Mochi1LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
|
||||
|
||||
## AmusedLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# SD3Transformer2D
|
||||
|
||||
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
|
||||
|
||||
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
|
||||
|
||||
<Tip>
|
||||
|
||||
To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## SD3Transformer2DLoadersMixin
|
||||
|
||||
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
|
||||
- all
|
||||
- _load_ip_adapter_weights
|
||||
@@ -29,6 +29,8 @@ The following DCAE models are released and supported in Diffusers.
|
||||
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)
|
||||
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)
|
||||
|
||||
This model was contributed by [lawrence-cj](https://github.com/lawrence-cj).
|
||||
|
||||
Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].
|
||||
|
||||
```python
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLHunyuanVideo
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanVideo
|
||||
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanVideo
|
||||
|
||||
[[autodoc]] AutoencoderKLHunyuanVideo
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTXVideo
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HunyuanVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] HunyuanVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTXVideoTransformer3DModel
|
||||
|
||||
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import SanaTransformer2DModel
|
||||
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## SanaTransformer2DModel
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
<!--Copyright 2024 The HuggingFace Team, The Black Forest Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# FluxControlInpaint
|
||||
|
||||
FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
|
||||
|
||||
FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
|
||||
|
||||
| Control type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
|
||||
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxControlInpaintPipeline
|
||||
from diffusers.models.transformers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
pipe = FluxControlInpaintPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Depth-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# use following lines if you have GPU constraints
|
||||
# ---------------------------------------------------------------
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.enable_model_cpu_offload()
|
||||
# ---------------------------------------------------------------
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "a blue robot singing opera with human-like expressions"
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
head_mask = np.zeros_like(image)
|
||||
head_mask[65:580,300:642] = 255
|
||||
mask_image = Image.fromarray(head_mask)
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(image)[0].convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=30,
|
||||
strength=0.9,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
|
||||
```
|
||||
|
||||
## FluxControlInpaintPipeline
|
||||
[[autodoc]] FluxControlInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## FluxPipelineOutput
|
||||
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
|
||||
@@ -268,6 +268,47 @@ images = pipe(
|
||||
images[0].save("flux-redux.png")
|
||||
```
|
||||
|
||||
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
|
||||
|
||||
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
|
||||
|
||||
```py
|
||||
from diffusers import FluxControlPipeline
|
||||
from image_gen_aux import DepthPreprocessor
|
||||
from diffusers.utils import load_image
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
|
||||
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
|
||||
control_pipe.load_lora_weights(
|
||||
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
|
||||
)
|
||||
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
|
||||
control_pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(control_image)[0].convert("RGB")
|
||||
|
||||
image = control_pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Note about `unload_lora_weights()` when using Flux LoRAs
|
||||
|
||||
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# HunyuanVideo
|
||||
|
||||
[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
|
||||
|
||||
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
Recommendations for inference:
|
||||
- Both text encoders should be in `torch.float16`.
|
||||
- Transformer should be in `torch.bfloat16`.
|
||||
- VAE should be in `torch.float16`.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
|
||||
|
||||
## HunyuanVideoPipeline
|
||||
|
||||
[[autodoc]] HunyuanVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HunyuanVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX
|
||||
# LTX Video
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
@@ -22,23 +22,37 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
Available models:
|
||||
|
||||
| Model name | Recommended dtype |
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
single_file_url, torch_dtype=torch.bfloat16
|
||||
)
|
||||
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
|
||||
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe = LTXImageToVideoPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# ... inference code ...
|
||||
```
|
||||
|
||||
Alternatively, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`].
|
||||
Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -46,11 +60,85 @@ from diffusers import LTXImageToVideoPipeline
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16)
|
||||
pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
"Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = T5Tokenizer.from_pretrained(
|
||||
"Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = LTXImageToVideoPipeline.from_single_file(
|
||||
single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
|
||||
)
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = LTXPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=704,
|
||||
height=480,
|
||||
num_frames=161,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
```
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
<!-- TODO(aryan): Update this when official weights are supported -->
|
||||
|
||||
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=768,
|
||||
height=512,
|
||||
num_frames=161,
|
||||
decode_timestep=0.03,
|
||||
decode_noise_scale=0.025,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
[[autodoc]] LTXPipeline
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# Mochi
|
||||
# Mochi 1 Preview
|
||||
|
||||
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
|
||||
|
||||
@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
## Generating videos with Mochi-1 Preview
|
||||
|
||||
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(prompt, num_frames=85).frames[0]
|
||||
|
||||
export_to_video(frames, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Using a lower precision variant to save memory
|
||||
|
||||
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
frames = pipe(prompt, num_frames=85).frames[0]
|
||||
|
||||
export_to_video(frames, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Reproducing the results from the Genmo Mochi repo
|
||||
|
||||
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
|
||||
|
||||
<Tip>
|
||||
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
|
||||
|
||||
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
|
||||
pipe.enable_vae_tiling()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
|
||||
pipe.encode_prompt(prompt=prompt)
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16):
|
||||
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
||||
frames = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
guidance_scale=4.5,
|
||||
num_inference_steps=64,
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=163,
|
||||
generator=torch.Generator("cuda").manual_seed(0),
|
||||
output_type="latent",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
video_processor = VideoProcessor(vae_scale_factor=8)
|
||||
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
frames = frames / pipe.vae.config.scaling_factor
|
||||
|
||||
with torch.no_grad():
|
||||
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
|
||||
|
||||
video = video_processor.postprocess_video(video)[0]
|
||||
export_to_video(video, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Running inference with multiple GPUs
|
||||
|
||||
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "genmo/mochi-1-preview"
|
||||
transformer = MochiTransformer3DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
device_map="auto",
|
||||
max_memory={0: "24GB", 1: "24GB"}
|
||||
)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(
|
||||
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
|
||||
negative_prompt="",
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=85,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_videos_per_prompt=1,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
max_sequence_length=256,
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Using single file loading with the Mochi Transformer
|
||||
|
||||
You can use `from_single_file` to load the Mochi transformer in its original format.
|
||||
|
||||
<Tip>
|
||||
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "genmo/mochi-1-preview"
|
||||
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
|
||||
|
||||
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(
|
||||
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
|
||||
negative_prompt="",
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=85,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_videos_per_prompt=1,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
max_sequence_length=256,
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## MochiPipeline
|
||||
|
||||
[[autodoc]] MochiPipeline
|
||||
|
||||
@@ -26,15 +26,15 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model]https://huggingface.co/Efficient-Large-Model).
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
|
||||
@@ -42,6 +42,8 @@ Available models:
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
|
||||
|
||||
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
|
||||
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
|
||||
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
|
||||
|
||||
## Image Prompting with IP-Adapters
|
||||
|
||||
An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
|
||||
|
||||
- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
|
||||
- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
|
||||
- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
|
||||
|
||||
IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
from transformers import SiglipVisionModel, SiglipImageProcessor
|
||||
|
||||
image_encoder_id = "google/siglip-so400m-patch14-384"
|
||||
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
|
||||
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
image_encoder_id,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
image_encoder_id,
|
||||
torch_dtype=torch.float16
|
||||
).to( "cuda")
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
torch_dtype=torch.float16,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
).to("cuda")
|
||||
|
||||
pipe.load_ip_adapter(ip_adapter_id)
|
||||
pipe.set_ip_adapter_scale(0.6)
|
||||
|
||||
ref_img = Image.open("image.jpg").convert('RGB')
|
||||
|
||||
image = pipe(
|
||||
width=1024,
|
||||
height=1024,
|
||||
prompt="a cat",
|
||||
negative_prompt="lowres, low quality, worst quality",
|
||||
num_inference_steps=24,
|
||||
guidance_scale=5.0,
|
||||
ip_adapter_image=ref_img
|
||||
).images[0]
|
||||
|
||||
image.save("result.jpg")
|
||||
```
|
||||
|
||||
<div class="justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/>
|
||||
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption>
|
||||
</div>
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Memory Optimisations for SD3
|
||||
|
||||
SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
|
||||
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
|
||||
|
||||
### Running Inference with Model Offloading
|
||||
|
||||
|
||||
@@ -28,6 +28,13 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
|
||||
|
||||
[[autodoc]] BitsAndBytesConfig
|
||||
|
||||
## GGUFQuantizationConfig
|
||||
|
||||
[[autodoc]] GGUFQuantizationConfig
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
## DiffusersQuantizer
|
||||
|
||||
[[autodoc]] quantizers.base.DiffusersQuantizer
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
-->
|
||||
|
||||
# GGUF
|
||||
|
||||
The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
|
||||
|
||||
The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
|
||||
|
||||
Before starting please install gguf in your environment
|
||||
|
||||
```shell
|
||||
pip install -U gguf
|
||||
```
|
||||
|
||||
Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
|
||||
|
||||
When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
|
||||
|
||||
The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
- BF16
|
||||
- Q4_0
|
||||
- Q4_1
|
||||
- Q5_0
|
||||
- Q5_1
|
||||
- Q8_0
|
||||
- Q2_K
|
||||
- Q3_K
|
||||
- Q4_K
|
||||
- Q5_K
|
||||
- Q6_K
|
||||
|
||||
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
|
||||
|
||||
<Tip>
|
||||
|
||||
Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
|
||||
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be
|
||||
|
||||
## When to use what?
|
||||
|
||||
This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# torchao
|
||||
|
||||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
|
||||
|
||||
Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
|
||||
|
||||
```bash
|
||||
pip install -U torch torchao
|
||||
```
|
||||
|
||||
|
||||
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
|
||||
|
||||
The example below only quantizes the weights to int8.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id,
|
||||
transformer=transformer,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Without quantization: ~31.447 GB
|
||||
# With quantization: ~20.40 GB
|
||||
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
|
||||
|
||||
```python
|
||||
# In the above code, add the following after initializing the transformer
|
||||
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
```
|
||||
|
||||
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
|
||||
|
||||
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
|
||||
|
||||
The `TorchAoConfig` class accepts three parameters:
|
||||
- `quant_type`: A string value mentioning one of the quantization types below.
|
||||
- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
|
||||
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
|
||||
|
||||
## Supported quantization types
|
||||
|
||||
torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
|
||||
|
||||
Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
|
||||
|
||||
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
|
||||
|
||||
The quantization methods supported are as follows:
|
||||
|
||||
| **Category** | **Full Function Names** | **Shorthands** |
|
||||
|--------------|-------------------------|----------------|
|
||||
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
|
||||
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
|
||||
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
|
||||
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
|
||||
|
||||
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
|
||||
|
||||
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
|
||||
## Serializing and Deserializing quantized models
|
||||
|
||||
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
|
||||
```
|
||||
|
||||
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
|
||||
# ...
|
||||
|
||||
# Load the model
|
||||
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
|
||||
with init_empty_weights():
|
||||
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
|
||||
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
|
||||
@@ -56,7 +56,7 @@ image
|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
|
||||
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte
|
||||
|
||||
You can also merge different adapter checkpoints for inference to blend their styles together.
|
||||
|
||||
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
|
||||
Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
|
||||
|
||||
```python
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
|
||||
> [!TIP]
|
||||
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
|
||||
|
||||
To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
|
||||
```python
|
||||
pipe.set_adapters("toy")
|
||||
@@ -127,7 +127,7 @@ image = pipe(
|
||||
image
|
||||
```
|
||||
|
||||
Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
|
||||
Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
|
||||
|
||||
```python
|
||||
pipe.disable_lora()
|
||||
@@ -140,7 +140,8 @@ image
|
||||

|
||||
|
||||
### Customize adapters strength
|
||||
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
|
||||
|
||||
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
|
||||
|
||||
For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
|
||||
```python
|
||||
@@ -195,7 +196,7 @@ image
|
||||
|
||||

|
||||
|
||||
## Manage active adapters
|
||||
## Manage adapters
|
||||
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
|
||||
|
||||
@@ -212,3 +213,11 @@ list_adapters_component_wise = pipe.get_list_adapters()
|
||||
list_adapters_component_wise
|
||||
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
|
||||
```
|
||||
|
||||
The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
|
||||
|
||||
```py
|
||||
pipe.delete_adapters("toy")
|
||||
pipe.get_active_adapters()
|
||||
["pixel"]
|
||||
```
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
check_min_version("0.32.0")
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -1008,6 +1008,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
self.transformer.inner_dim // self.transformer.num_heads,
|
||||
grid_crops_coords,
|
||||
(grid_height, grid_width),
|
||||
device=device,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
style = torch.tensor([0], device=device)
|
||||
|
||||
@@ -129,7 +129,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if isinstance(prompt, list) else [prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
|
||||
n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
|
||||
if use_base:
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
# DreamBooth training example for SANA
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629).
|
||||
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sana.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-sana-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sana.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
For using `push_to_hub`, make you're logged into your Hugging Face account:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
## Notes
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
* `--complex_human_instruction`: Instructions for complex human attention as shown in [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).
|
||||
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
|
||||
|
||||
|
||||
We provide several options for optimizing memory optimization:
|
||||
|
||||
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=1.0.0
|
||||
torchvision
|
||||
transformers>=4.47.0
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.14.0
|
||||
sentencepiece
|
||||
@@ -0,0 +1,206 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
|
||||
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
|
||||
|
||||
def test_dreambooth_lora_sana(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# `self.transformer_layer_type` should be in the state dict.
|
||||
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 166
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
resume_run_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import safetensors.torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
vision = True
|
||||
else:
|
||||
vision = False
|
||||
|
||||
"""
|
||||
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
|
||||
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
|
||||
--filename "flux-ip-adapter.safetensors"
|
||||
--output_path "flux-ip-adapter-hf/"
|
||||
"""
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--filename", default="flux.safetensors", type=str)
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--output_path", type=str)
|
||||
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
|
||||
converted_state_dict = {}
|
||||
|
||||
# image_proj
|
||||
## norm
|
||||
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
||||
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
||||
## proj
|
||||
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
||||
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"ip_adapter.{i}."
|
||||
# to_k_ip
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
|
||||
)
|
||||
# to_v_ip
|
||||
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def main(args):
|
||||
original_ckpt = load_original_checkpoint(args)
|
||||
|
||||
num_layers = 19
|
||||
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
|
||||
|
||||
print("Saving Flux IP-Adapter in Diffusers format.")
|
||||
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
|
||||
|
||||
if vision:
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
|
||||
model.save_pretrained(f"{args.output_path}/image_encoder")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -0,0 +1,257 @@
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanVideoPipeline,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
||||
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
||||
|
||||
elif "linear1.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
||||
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(ckpt_path: str):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = HunyuanVideoTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLHunyuanVideo()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
|
||||
parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
|
||||
parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
assert args.text_encoder_2_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_ckpt_path)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
||||
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
|
||||
pipe = HunyuanVideoPipeline(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
@@ -21,7 +23,9 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"vae": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# decoder
|
||||
@@ -54,10 +58,31 @@ VAE_KEYS_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
@@ -80,13 +105,16 @@ def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = ""
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -97,16 +125,21 @@ def convert_transformer(
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "vae."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTXVideo(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True)
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -139,6 +222,9 @@ def get_args():
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -161,6 +247,7 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
@@ -169,13 +256,14 @@ if __name__ == "__main__":
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
|
||||
config = get_vae_config(args.version)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
|
||||
@@ -87,13 +88,18 @@ def main(args):
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
layer_num = 28
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -175,6 +181,7 @@ def main(args):
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -265,9 +272,9 @@ if __name__ == "__main__":
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[512, 1024],
|
||||
choices=[512, 1024, 2048],
|
||||
required=False,
|
||||
help="Image size of pretrained model, 512 or 1024.",
|
||||
help="Image size of pretrained model, 512, 1024 or 2048.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.32.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.32.0.dev0"
|
||||
__version__ = "0.32.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -31,7 +31,7 @@ _import_structure = {
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig"],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
"OptionalDependencyNotAvailable",
|
||||
@@ -84,6 +84,7 @@ else:
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
@@ -102,6 +103,7 @@ else:
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"LatteTransformer3DModel",
|
||||
@@ -275,6 +277,7 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
"FluxControlNetInpaintPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
@@ -287,6 +290,7 @@ else:
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanVideoPipeline",
|
||||
"I2VGenXLPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
@@ -566,7 +570,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
@@ -590,6 +594,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
@@ -608,6 +613,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
LatteTransformer3DModel,
|
||||
@@ -760,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
@@ -772,6 +779,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
I2VGenXLPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
|
||||
@@ -55,7 +55,8 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
|
||||
|
||||
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
|
||||
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
if is_transformers_available():
|
||||
@@ -65,13 +66,20 @@ if is_torch_available():
|
||||
"StableDiffusionLoraLoaderMixin",
|
||||
"SD3LoraLoaderMixin",
|
||||
"StableDiffusionXLLoraLoaderMixin",
|
||||
"LTXVideoLoraLoaderMixin",
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
"IPAdapterMixin",
|
||||
"FluxIPAdapterMixin",
|
||||
"SD3IPAdapterMixin",
|
||||
]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
@@ -79,17 +87,26 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file_model import FromOriginalModelMixin
|
||||
from .transformer_flux import FluxTransformer2DLoadersMixin
|
||||
from .transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
if is_transformers_available():
|
||||
from .ip_adapter import IPAdapterMixin
|
||||
from .ip_adapter import (
|
||||
FluxIPAdapterMixin,
|
||||
IPAdapterMixin,
|
||||
SD3IPAdapterMixin,
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
SanaLoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
|
||||
@@ -33,15 +33,20 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
JointAttnProcessor2_0,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -348,3 +353,519 @@ class IPAdapterMixin:
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||
for key in f.keys():
|
||||
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||
diffusers_name = ".".join(key.split(".")[1:])
|
||||
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||
diffusers_name = (
|
||||
".".join(key.split(".")[1:])
|
||||
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||
.replace("processor.", "")
|
||||
)
|
||||
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_pretrained_model_name_or_path is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||
image_encoder = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_pretrained_model_name_or_path,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
.to(self.device, dtype=image_encoder_dtype)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
|
||||
def LinearStrengthModel(start, finish, size):
|
||||
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||
|
||||
|
||||
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
transformer = self.transformer
|
||||
if not isinstance(scale, list):
|
||||
scale = [[scale] * transformer.config.num_layers]
|
||||
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
|
||||
if len(scale) != transformer.config.num_layers:
|
||||
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
|
||||
scale = [scale]
|
||||
|
||||
scale_configs = scale
|
||||
|
||||
key_id = 0
|
||||
for attn_name, attn_processor in transformer.attn_processors.items():
|
||||
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
attn_processor.scale[i] = scale_config[key_id]
|
||||
key_id += 1
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.transformer.encoder_hid_proj = None
|
||||
self.transformer.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Transformer attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.transformer.attn_processors.items():
|
||||
attn_processor_class = FluxAttnProcessor2_0()
|
||||
attn_procs[name] = (
|
||||
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||
)
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class SD3IPAdapterMixin:
|
||||
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||
|
||||
@property
|
||||
def is_ip_adapter_active(self) -> bool:
|
||||
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||
|
||||
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||
the image context is irrelevant.
|
||||
|
||||
Returns:
|
||||
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||
"""
|
||||
scales = [
|
||||
attn_proc.scale
|
||||
for attn_proc in self.transformer.attn_processors.values()
|
||||
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||
]
|
||||
|
||||
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# Load the main state dict first
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
# Commons args for loading image encoder and image processor
|
||||
kwargs = {
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# Load IP-Adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||
"""
|
||||
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
|
||||
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
scale (float):
|
||||
IP-Adapter scale to be set.
|
||||
|
||||
"""
|
||||
for attn_processor in self.transformer.attn_processors.values():
|
||||
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self) -> None:
|
||||
"""
|
||||
Unloads the IP Adapter weights.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# Remove image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=None)
|
||||
|
||||
# Remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=None)
|
||||
|
||||
# Remove image projection
|
||||
self.transformer.image_proj = None
|
||||
|
||||
# Restore original attention processors layers
|
||||
attn_procs = {
|
||||
name: (
|
||||
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||
)
|
||||
for name, value in self.transformer.attn_processors.items()
|
||||
}
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
@@ -28,13 +28,20 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
@@ -43,6 +50,8 @@ from ..utils import (
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
|
||||
return weight_name
|
||||
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
text_encoder_name="text_encoder",
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
@@ -327,27 +482,7 @@ class LoraBaseMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
|
||||
@@ -643,7 +643,11 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
|
||||
[
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
|
||||
],
|
||||
)
|
||||
|
||||
if "down" in old_key:
|
||||
@@ -969,3 +973,178 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
|
||||
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
|
||||
else:
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
|
||||
else:
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_A.weight"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
|
||||
state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
|
||||
state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
|
||||
state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
|
||||
else:
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_B.weight"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
|
||||
|
||||
elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_A.bias"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
|
||||
state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
|
||||
state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
|
||||
state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
|
||||
else:
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_B.bias"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
|
||||
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
|
||||
# sure that both follow the same initial format by stripping off the "transformer." prefix.
|
||||
for key in list(converted_state_dict.keys()):
|
||||
if key.startswith("transformer."):
|
||||
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
|
||||
if key.startswith("diffusion_model."):
|
||||
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
|
||||
|
||||
# Rename and remap the state dict keys
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
# Add back the "transformer." prefix
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
+1185
-634
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import (
|
||||
MIN_PEFT_VERSION,
|
||||
@@ -30,20 +29,16 @@ from ..utils import (
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
logging,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_base import _fetch_state_dict
|
||||
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
@@ -53,6 +48,9 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -137,27 +135,7 @@ class PeftAdapterMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
|
||||
r"""
|
||||
|
||||
@@ -17,8 +17,10 @@ import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
@@ -26,10 +28,12 @@ from .single_file_utils import (
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
@@ -94,6 +98,14 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -213,7 +225,10 @@ class FromOriginalModelMixin:
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
@@ -227,6 +242,12 @@ class FromOriginalModelMixin:
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
@@ -282,7 +303,7 @@ class FromOriginalModelMixin:
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
revision=config_revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
@@ -309,8 +330,36 @@ class FromOriginalModelMixin:
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device=param_device,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
@@ -324,7 +373,11 @@ class FromOriginalModelMixin:
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
"sd3": [
|
||||
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
],
|
||||
"sd35_large": [
|
||||
"joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
],
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
@@ -93,13 +99,16 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
],
|
||||
"ltx-video": [
|
||||
(
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
),
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
"patchify_proj.weight",
|
||||
"transformer_blocks.27.scale_shift_table",
|
||||
"vae.per_channel_statistics.mean-of-means",
|
||||
],
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -145,12 +154,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
|
||||
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
|
||||
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
|
||||
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -542,13 +556,20 @@ def infer_diffusers_model_type(checkpoint):
|
||||
):
|
||||
model_type = "stable_cascade_stage_b"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
|
||||
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
|
||||
checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
|
||||
):
|
||||
if "model.diffusion_model.pos_embed" in checkpoint:
|
||||
key = "model.diffusion_model.pos_embed"
|
||||
else:
|
||||
key = "pos_embed"
|
||||
|
||||
if checkpoint[key].shape[1] == 36864:
|
||||
model_type = "sd3"
|
||||
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
|
||||
elif checkpoint[key].shape[1] == 147456:
|
||||
model_type = "sd35_medium"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
|
||||
model_type = "sd35_large"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
@@ -574,12 +595,25 @@ def infer_diffusers_model_type(checkpoint):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
model_type = "flux-dev"
|
||||
if "model.diffusion_model.img_in.weight" in checkpoint:
|
||||
key = "model.diffusion_model.img_in.weight"
|
||||
else:
|
||||
key = "img_in.weight"
|
||||
|
||||
if checkpoint[key].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
elif checkpoint[key].shape[1] == 128:
|
||||
model_type = "flux-depth"
|
||||
else:
|
||||
model_type = "flux-dev"
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
model_type = "ltx-video"
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
encoder_key = "encoder.project_in.conv.conv.bias"
|
||||
@@ -597,6 +631,12 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "autoencoder-dc-f128c512"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
|
||||
model_type = "mochi-1-preview"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
|
||||
model_type = "hunyuan-video"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1737,6 +1777,12 @@ def swap_scale_shift(weight, dim):
|
||||
return new_weight
|
||||
|
||||
|
||||
def swap_proj_gate(weight):
|
||||
proj, gate = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([gate, proj], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def get_attn2_layers(state_dict):
|
||||
attn2_layers = []
|
||||
for key in state_dict.keys():
|
||||
@@ -2234,9 +2280,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
|
||||
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {
|
||||
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
|
||||
}
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"model.diffusion_model.": "",
|
||||
@@ -2302,12 +2346,32 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
@@ -2393,3 +2457,231 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict = {}
|
||||
|
||||
# Comfy checkpoints add this prefix
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
# Convert patch_embed
|
||||
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Convert time_embed
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
|
||||
# Convert transformer blocks
|
||||
num_layers = 48
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
old_prefix = f"blocks.{i}."
|
||||
|
||||
# norm1
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
|
||||
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
else:
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
|
||||
# Visual attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
|
||||
# Context attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_y.weight"
|
||||
)
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
|
||||
|
||||
# MLP
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
|
||||
|
||||
# Output layers
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
||||
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
||||
|
||||
elif "linear1.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
||||
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
def update_state_dict_(state_dict, old_key, new_key):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(checkpoint, key, new_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, checkpoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
|
||||
from ..models.attention_processor import (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict
|
||||
|
||||
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
|
||||
|
||||
class SD3Transformer2DLoadersMixin:
|
||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict`):
|
||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
||||
"image_proj", which contains parameters for image projection net.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# IP-Adapter cross attention parameters
|
||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
|
||||
|
||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
||||
for key, weights in state_dict["ip_adapter"].items():
|
||||
idx, name = key.split(".", maxsplit=1)
|
||||
layer_state_dict[int(idx)][name] = weights
|
||||
|
||||
# Create IP-Adapter attention processor
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(self.attn_processors.keys()):
|
||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
||||
head_dim=self.config.attention_head_dim,
|
||||
timesteps_emb_dim=timesteps_emb_dim,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
# Image projetion parameters
|
||||
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
|
||||
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
|
||||
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
|
||||
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
num_queries = state_dict["image_proj"]["latents"].shape[1]
|
||||
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
# Image projection
|
||||
self.image_proj = IPAdapterTimeImageProjection(
|
||||
embed_dim=embed_dim,
|
||||
output_dim=output_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
timestep_in_dim=timestep_in_dim,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
|
||||
@@ -21,7 +21,6 @@ import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
@@ -44,13 +43,11 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -400,27 +397,7 @@ class UNet2DConditionLoadersMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
|
||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
@@ -67,6 +68,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
@@ -97,6 +99,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
@@ -130,6 +133,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DualTransformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
|
||||
@@ -164,3 +164,15 @@ class ApproximateGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class LinearActivation(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.activation = get_activation(activation)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return self.activation(hidden_states)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
hidden_states, emb=temb
|
||||
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
if self.use_dual_attention:
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
|
||||
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
|
||||
hidden_states = hidden_states + attn_output2
|
||||
|
||||
@@ -1222,6 +1229,8 @@ class FeedForward(nn.Module):
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "linear-silu":
|
||||
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
|
||||
@@ -254,14 +254,22 @@ class Attention(nn.Module):
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
else:
|
||||
self.add_q_proj = None
|
||||
self.add_k_proj = None
|
||||
self.add_v_proj = None
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
else:
|
||||
self.to_out = None
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
|
||||
else:
|
||||
self.to_add_out = None
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "fp32_layer_norm":
|
||||
@@ -567,7 +575,7 @@ class Attention(nn.Module):
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks"}
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
||||
unused_kwargs = [
|
||||
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
|
||||
]
|
||||
@@ -782,7 +790,11 @@ class Attention(nn.Module):
|
||||
self.to_kv.bias.copy_(concatenated_bias)
|
||||
|
||||
# handle added projections for SD3 and others.
|
||||
if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
|
||||
if (
|
||||
getattr(self, "add_q_proj", None) is not None
|
||||
and getattr(self, "add_k_proj", None) is not None
|
||||
and getattr(self, "add_v_proj", None) is not None
|
||||
):
|
||||
concatenated_weights = torch.cat(
|
||||
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
|
||||
)
|
||||
@@ -894,6 +906,177 @@ class SanaMultiscaleLinearAttention(nn.Module):
|
||||
return self.processor(self, hidden_states)
|
||||
|
||||
|
||||
class MochiAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
added_kv_proj_dim: int,
|
||||
processor: "MochiAttnProcessor2_0",
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
added_proj_bias: bool = True,
|
||||
out_dim: Optional[int] = None,
|
||||
out_context_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
context_pre_only: bool = False,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
from .normalization import MochiRMSNorm
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
|
||||
self.norm_q = MochiRMSNorm(dim_head, eps, True)
|
||||
self.norm_k = MochiRMSNorm(dim_head, eps, True)
|
||||
self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
|
||||
self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
|
||||
|
||||
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
|
||||
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MochiAttnProcessor2_0:
|
||||
"""Attention processor used in Mochi."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "MochiAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos, freqs_sin):
|
||||
x_even = x[..., 0::2].float()
|
||||
x_odd = x[..., 1::2].float()
|
||||
|
||||
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
|
||||
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
|
||||
|
||||
return torch.stack([cos, sin], dim=-1).flatten(-2)
|
||||
|
||||
query = apply_rotary_emb(query, *image_rotary_emb)
|
||||
key = apply_rotary_emb(key, *image_rotary_emb)
|
||||
|
||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||
encoder_query, encoder_key, encoder_value = (
|
||||
encoder_query.transpose(1, 2),
|
||||
encoder_key.transpose(1, 2),
|
||||
encoder_value.transpose(1, 2),
|
||||
)
|
||||
|
||||
sequence_length = query.size(2)
|
||||
encoder_sequence_length = encoder_query.size(2)
|
||||
total_length = sequence_length + encoder_sequence_length
|
||||
|
||||
batch_size, heads, _, dim = query.shape
|
||||
attn_outputs = []
|
||||
for idx in range(batch_size):
|
||||
mask = attention_mask[idx][None, :]
|
||||
valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
|
||||
|
||||
valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
|
||||
valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
|
||||
valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
|
||||
|
||||
valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
|
||||
valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
|
||||
valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
valid_sequence_length = attn_output.size(2)
|
||||
attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
|
||||
attn_outputs.append(attn_output)
|
||||
|
||||
hidden_states = torch.cat(attn_outputs, dim=0)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
|
||||
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
||||
(sequence_length, encoder_sequence_length), dim=1
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if hasattr(attn, "to_add_out"):
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
@@ -2470,6 +2653,149 @@ class FusedFluxAttnProcessor2_0_NPU:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
ip_hidden_states: Optional[List[torch.Tensor]] = None,
|
||||
ip_adapter_masks: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
hidden_states_query_proj = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
if encoder_hidden_states is not None:
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
# IP-adapter
|
||||
ip_query = hidden_states_query_proj
|
||||
ip_attn_output = None
|
||||
# for ip-adapter
|
||||
# TODO: support for multiple adapters
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_attn_output = F.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_attn_output = scale * ip_attn_output
|
||||
ip_attn_output = ip_attn_output.to(ip_query.dtype)
|
||||
|
||||
return hidden_states, encoder_hidden_states, ip_attn_output
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
@@ -3856,94 +4182,6 @@ class LuminaAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MochiAttnProcessor2_0:
|
||||
"""Attention processor used in Mochi."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos, freqs_sin):
|
||||
x_even = x[..., 0::2].float()
|
||||
x_odd = x[..., 1::2].float()
|
||||
|
||||
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
|
||||
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
|
||||
|
||||
return torch.stack([cos, sin], dim=-1).flatten(-2)
|
||||
|
||||
query = apply_rotary_emb(query, *image_rotary_emb)
|
||||
key = apply_rotary_emb(key, *image_rotary_emb)
|
||||
|
||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||
encoder_query, encoder_key, encoder_value = (
|
||||
encoder_query.transpose(1, 2),
|
||||
encoder_key.transpose(1, 2),
|
||||
encoder_value.transpose(1, 2),
|
||||
)
|
||||
|
||||
sequence_length = query.size(2)
|
||||
encoder_sequence_length = encoder_query.size(2)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=2)
|
||||
key = torch.cat([key, encoder_key], dim=2)
|
||||
value = torch.cat([value, encoder_value], dim=2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
||||
(sequence_length, encoder_sequence_length), dim=1
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if hasattr(attn, "to_add_out"):
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
|
||||
@@ -5148,6 +5386,177 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
"""
|
||||
Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
|
||||
additional image-based information and timestep embeddings.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The number of hidden channels.
|
||||
ip_hidden_states_dim (`int`):
|
||||
The image feature dimension.
|
||||
head_dim (`int`):
|
||||
The number of head channels.
|
||||
timesteps_emb_dim (`int`, defaults to 1280):
|
||||
The number of input channels for timestep embedding.
|
||||
scale (`float`, defaults to 0.5):
|
||||
IP-Adapter scale.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ip_hidden_states_dim: int,
|
||||
head_dim: int,
|
||||
timesteps_emb_dim: int = 1280,
|
||||
scale: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import
|
||||
from .normalization import AdaLayerNorm, RMSNorm
|
||||
|
||||
self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
|
||||
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
|
||||
self.norm_q = RMSNorm(head_dim, 1e-6)
|
||||
self.norm_k = RMSNorm(head_dim, 1e-6)
|
||||
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
|
||||
self.scale = scale
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
ip_hidden_states: torch.FloatTensor = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Perform the attention computation, integrating image features (if provided) and timestep embeddings.
|
||||
|
||||
If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
|
||||
|
||||
Args:
|
||||
attn (`Attention`):
|
||||
Attention instance.
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
The encoder hidden states.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Attention mask.
|
||||
ip_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Image embeddings.
|
||||
temb (`torch.FloatTensor`, *optional*):
|
||||
Timestep embeddings.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Output hidden states.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
img_query = query
|
||||
img_key = key
|
||||
img_value = value
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
# Split the attention outputs.
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : residual.shape[1]],
|
||||
hidden_states[:, residual.shape[1] :],
|
||||
)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
# IP Adapter
|
||||
if self.scale != 0 and ip_hidden_states is not None:
|
||||
# Norm image features
|
||||
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
|
||||
|
||||
# To k and v
|
||||
ip_key = self.to_k_ip(norm_ip_hidden_states)
|
||||
ip_value = self.to_v_ip(norm_ip_hidden_states)
|
||||
|
||||
# Reshape
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Norm
|
||||
query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
ip_key = self.norm_ip_k(ip_key)
|
||||
|
||||
# cat img
|
||||
key = torch.cat([img_key, ip_key], dim=2)
|
||||
value = torch.cat([img_value, ip_value], dim=2)
|
||||
|
||||
ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + ip_hidden_states * self.scale
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
@@ -5411,21 +5820,37 @@ class SanaMultiscaleAttnProcessor2_0:
|
||||
|
||||
|
||||
class LoRAAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing attention with LoRA.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAXFormersAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing attention with LoRA using xFormers.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAAttnAddedKVProcessor:
|
||||
r"""
|
||||
Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -5614,6 +6039,7 @@ CROSS_ATTENTION_PROCESSORS = (
|
||||
SlicedAttnProcessor,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
AttentionProcessor = Union[
|
||||
@@ -5640,13 +6066,13 @@ AttentionProcessor = Union[
|
||||
AttnProcessorNPU,
|
||||
AttnProcessor2_0,
|
||||
MochiVaeAttnProcessor2_0,
|
||||
MochiAttnProcessor2_0,
|
||||
StableAudioAttnProcessor2_0,
|
||||
HunyuanAttnProcessor2_0,
|
||||
FusedHunyuanAttnProcessor2_0,
|
||||
PAGHunyuanAttnProcessor2_0,
|
||||
PAGCFGHunyuanAttnProcessor2_0,
|
||||
LuminaAttnProcessor2_0,
|
||||
MochiAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
@@ -5661,6 +6087,7 @@ AttentionProcessor = Union[
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
|
||||
@@ -3,6 +3,7 @@ from .autoencoder_dc import AutoencoderDC
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user