Compare commits

..

86 Commits

Author SHA1 Message Date
yiyixuxu c73c00610e add:
q
2024-12-22 10:21:00 +01:00
Mehmet Yiğit Özgenç 233dffdc3f flux controlnet inpaint config bug (#10291)
* flux controlnet inpaint config bug

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

---------

Co-authored-by: yigitozgenc <yigit@quantuslabs.ai>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-21 18:44:43 +00:00
hlky be2070991f Support Flux IP Adapter (#10261)
* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-21 17:49:58 +00:00
hlky bf9a641f1a Fix EMAModel test_from_pretrained (#10325) 2024-12-21 14:10:44 +00:00
hlky a756694bf0 Fix push_tests_mps.yml (#10326) 2024-12-21 14:10:32 +00:00
Sayak Paul d41388145e [Docs] Update gguf.md to remove generator from the pipeline from_pretrained (#10299)
Update gguf.md to remove generator from the pipeline from_pretrained
2024-12-21 07:15:03 +05:30
Junsong Chen a6288a5571 [Sana]add 2K related model for Sana (#10322)
add 2K related model for Sana
2024-12-20 07:21:34 -10:00
Steven Liu 7d4db57037 [docs] Fix quantization links (#10323)
Update overview.md
2024-12-20 08:30:21 -08:00
Aditya Raj 902008608a [BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)
[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float

Co-authored-by: hlky <hlky@hlky.ac>
2024-12-20 15:29:58 +00:00
Leojc c8ee4af228 docs: fix a mistake in docstring (#10319)
Update pipeline_hunyuan_video.py

docs: fix a mistake
2024-12-20 15:22:32 +00:00
Sayak Paul b64ca6c11c [Docs] Update ltx_video.md to remove generator from from_pretrained() (#10316)
Update ltx_video.md to remove generator from `from_pretrained()`
2024-12-20 18:32:22 +05:30
Dhruv Nair e12d610faa Mochi docs (#9934)
* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-20 16:27:38 +05:30
Sayak Paul bf6eaa8aec [Tests] add integration tests for lora expansion stuff in Flux. (#10318)
add integration tests for lora expansion stuff in Flux.
2024-12-20 16:14:58 +05:30
Sayak Paul 17128c42a4 [LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill (#10259)
* lora expansion with dummy zeros.

* updates

* fix working 🥳

* working.

* use torch.device meta for state dict expansion.

* tests

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>

* fixes

* fixes

* switch to debug

* fix

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* fix stuff

* docs

---------

Co-authored-by: a-r-r-o-w <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-20 14:30:32 +05:30
Dhruv Nair dbc1d505f0 [Single File] Add GGUF support for LTX (#10298)
* update

* add docs.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-20 11:52:29 +05:30
Aryan 151b74cd77 Make tensors in ResNet contiguous for Hunyuan VAE (#10309)
contiguous tensors in resnet

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-20 11:45:37 +05:30
Aryan 41ba8c0bf6 Add support for sharded models when TorchAO quantization is enabled (#10256)
* add sharded + device_map check
2024-12-19 15:42:20 -10:00
Daniel Regado 3191248472 [WIP] SD3.5 IP-Adapter Pipeline Integration (#9987)
* Added support for single IPAdapter on SD3.5 pipeline



---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-19 14:48:18 -10:00
dg845 648d968cfc Enable Gradient Checkpointing for UNet2DModel (New) (#7201)
* Port UNet2DModel gradient checkpointing code from #6718.


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-19 14:45:45 -10:00
djm b756ec6e80 unet's sample_size attribute is to accept tuple(h, w) in StableDiffusionPipeline (#10181) 2024-12-19 22:24:18 +00:00
Aryan d8825e7697 Fix failing lora tests after HunyuanVideo lora (#10307)
fix
2024-12-20 02:35:41 +05:30
hlky 074798b299 Fix local_files_only for checkpoints with shards (#10294) 2024-12-19 07:04:57 -10:00
Dhruv Nair 3ee966950b Allow Mochi Transformer to be split across multiple GPUs (#10300)
update
2024-12-19 22:34:44 +05:30
Dhruv Nair 9764f229d4 [Single File] Add single file support for Mochi Transformer (#10268)
update
2024-12-19 22:20:40 +05:30
Shenghai Yuan 1826a1e7d3 [LoRA] Support HunyuanVideo (#10254)
* 1217

* 1217

* 1217

* update

* reverse

* add test

* update test

* make style

* update

* make style

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-19 16:22:20 +05:30
hlky 0ed09a17bb Check correct model type is passed to from_pretrained (#10189)
* Check correct model type is passed to `from_pretrained`

* Flax, skip scheduler

* test_wrong_model

* Fix for scheduler

* Update tests/pipelines/test_pipelines.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* EnumMeta

* Flax

* scheduler in expected types

* make

* type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name'

* support union

* fix typing in kandinsky

* make

* add LCMScheduler

* 'LCMScheduler' object has no attribute 'sigmas'

* tests for wrong scheduler

* make

* update

* warning

* tests

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* import FlaxSchedulerMixin

* skip scheduler

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-12-19 09:24:52 +00:00
赵三石 2f7a417d1f Update lora_conversion_utils.py (#9980)
x-flux single-blocks lora load

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-18 23:07:50 -10:00
hlky 4450d26b63 Add Flux Control to AutoPipeline (#10292) 2024-12-18 22:28:56 -10:00
Aryan f781b8c30c Hunyuan VAE tiling fixes and transformer docs (#10295)
* update

* udpate

* fix test
2024-12-19 10:28:10 +05:30
Sayak Paul 9c0e20de61 [chore] Update README_sana.md to update the default model (#10285)
Update README_sana.md to update the default model
2024-12-19 10:24:57 +05:30
Aryan f35a38725b [tests] remove nullop import checks from lora tests (#10273)
remove nullop imports
2024-12-19 01:19:08 +05:30
Aryan f66bd3261c Rename Mochi integration test correctly (#10220)
rename integration test
2024-12-18 22:41:23 +05:30
Aryan c4c99c3907 [tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (#10270)
fix joint pos embedding device
2024-12-18 22:36:08 +05:30
Dhruv Nair 862a7d5038 [Single File] Add single file support for Flux Canny, Depth and Fill (#10288)
update
2024-12-18 19:19:47 +05:30
Dhruv Nair 8304adce2a Make zeroing prompt embeds for Mochi Pipeline configurable (#10284)
update
2024-12-18 18:32:53 +05:30
Dhruv Nair b389f339ec Fix Doc links in GGUF and Quantization overview docs (#10279)
* update

* Update docs/source/en/quantization/gguf.md

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-18 18:32:36 +05:30
hlky e222246b4e Fix sigma_last with use_flow_sigmas (#10267) 2024-12-18 12:22:10 +00:00
Andrés Romero 83709d5a06 Flux Control(Depth/Canny) + Inpaint (#10192)
* flux_control_inpaint - failing test_flux_different_prompts

* removing test_flux_different_prompts?

* fix style

* fix from PR comments

* fix style

* reducing guidance_scale in demo

* Update src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Co-authored-by: hlky <hlky@hlky.ac>

* make

* prepare_latents is not copied from

* update docs

* typos

---------

Co-authored-by: affromero <ubuntu@ip-172-31-17-146.ec2.internal>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-18 09:14:16 +00:00
Qin Zhou 8eb73c872a Support pass kwargs to sd3 custom attention processor (#9818)
* Support pass kwargs to sd3 custom attention processor


---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-17 21:58:33 -10:00
Xinyuan Zhao 88b015dc9f Make time_embed_dim of UNet2DModel changeable (#10262) 2024-12-17 21:55:18 -10:00
Sayak Paul 63cdf9c0ba [chore] fix: reamde -> readme (#10276)
fix: reamde -> readme
2024-12-18 10:56:08 +05:30
hlky 0ac52d6f09 Use torch in get_2d_rotary_pos_embed (#10155)
* Use `torch` in `get_2d_rotary_pos_embed`

* Add deprecation
2024-12-17 18:26:52 -10:00
Sayak Paul ba6fd6eb30 [chore] fix: licensing headers in mochi and ltx (#10275)
fix: licensing header.
2024-12-18 08:43:57 +05:30
Sayak Paul 9408aa2dfc [LoRA] feat: lora support for SANA. (#10234)
* feat: lora support for SANA.

* make fix-copies

* rename test class.

* attention_kwargs -> cross_attention_kwargs.

* Revert "attention_kwargs -> cross_attention_kwargs."

This reverts commit 23433bf9bc.

* exhaust 119 max line limit

* sana lora fine-tuning script.

* readme

* add a note about the supported models.

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* style

* docs for attention_kwargs.

* remove lora_scale from pag pipeline.

* copy fix

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-18 08:22:31 +05:30
hlky ec1c7a793f Add set_shift to FlowMatchEulerDiscreteScheduler (#10269) 2024-12-17 21:40:09 +00:00
cjkangme 9c68c945e9 [Community Pipeline] Fix typo that cause error on regional prompting pipeline (#10251)
fix: fix typo that cause error
2024-12-17 21:09:50 +00:00
Steven Liu 2739241ad1 [docs] delete_adapters() (#10245)
delete_adapters

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-17 09:26:45 -08:00
Aryan 1524781b88 [tests] Remove/rename unsupported quantization torchao type (#10263)
update
2024-12-17 21:43:15 +05:30
Dhruv Nair 128b96f369 Fix Mochi Quality Issues (#10033)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/models/transformers/transformer_mochi.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-17 19:40:00 +05:30
Dhruv Nair e24941b2a7 [Single File] Add GGUF support (#9964)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/gguf/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/gguf.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-12-17 16:09:37 +05:30
Aryan f9d5a9324d [docs] Clarify dtypes for Sana (#10248)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-17 13:43:24 +05:30
Aryan ac86393487 [LoRA] Support LTX Video (#10228)
* add lora support for ltx

* add tests

* fix copied from comments

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-17 12:05:05 +05:30
Aryan 0d96a894a7 Fix copied from comment in Mochi lora loader (#10255)
update
2024-12-17 11:09:57 +05:30
Sayak Paul 6fb94d51cb [chore] add contribution note for lawrence. (#10253)
add contribution note for lawrence.
2024-12-17 09:17:40 +05:30
Steven Liu 7667cfcb41 [docs] Add missing AttnProcessors (#10246)
* attnprocessors

* lora

* make style

* fix

* fix

* sana

* typo
2024-12-16 15:36:26 -08:00
Aryan 9f00c617a0 [core] TorchAO Quantizer (#10009)
* torchao quantizer


---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-12-16 13:35:40 -10:00
Kaiwen Sheng aafed3f8dd fix downsample bug in MidResTemporalBlock1D (#10250) 2024-12-17 04:55:16 +05:30
hlky 5ed761a6f2 Add ControlNetUnion to AutoPipeline from_pretrained (#10219) 2024-12-16 10:25:08 -10:00
hlky 2f023d7b84 Fix RePaint Scheduler (#10185)
Fix repaint scheduler
2024-12-16 09:38:13 -10:00
hlky e9a3911b67 Fix checkpoint in CogView3PlusPipeline example (#10211) 2024-12-16 09:31:22 -10:00
hlky 7186bb45f0 Add enable_vae_tiling to AllegroPipeline, fix example (#10212) 2024-12-16 09:31:02 -10:00
hlky 438bd60549 Use non-human subject in StableDiffusion3ControlNetPipeline example (#10214)
* Use non-human subject in StableDiffusion3ControlNetPipeline example

* make style
2024-12-16 09:30:26 -10:00
hlky 87e8157437 Fix ControlNetUnion _callback_tensor_inputs (#10218) 2024-12-16 09:29:12 -10:00
hlky 3f421fe09f Fix use_flow_sigmas (#10242)
use_flow_sigmas copy
2024-12-16 09:27:22 -10:00
hlky a7d50524dd Add dynamic_shifting to SD3 (#10236)
* Add `dynamic_shifting` to SD3

* calculate_shift

* FlowMatchHeunDiscreteScheduler doesn't support mu

* Inpaint/img2img
2024-12-16 09:25:21 -10:00
hlky 672bd49573 Use t instead of timestep in _apply_perturbed_attention_guidance (#10243) 2024-12-16 09:24:16 -10:00
Sayak Paul ea893a9ae7 [Docs] add rest of the lora loader mixins to the docs. (#10230)
add rest of the lora loader mixins to the docs.
2024-12-16 08:50:27 -08:00
fancy45daddy 5fb3a98517 Update pipeline_controlnet.py add support for pytorch_xla (#10222)
* Update pipeline_controlnet.py

* make style

---------

Co-authored-by: hlky <hlky@hlky.ac>
2024-12-16 09:05:50 +00:00
Aryan aace1f412b [core] Hunyuan Video (#10136)
* copy transformer

* copy vae

* copy pipeline

* make fix-copies

* refactor; make original code work with diffusers; test latents for comparison generated with this commit

* move rope into pipeline; remove flash attention; refactor

* begin conversion script

* make style

* refactor attention

* refactor

* refactor final layer

* their mlp -> our feedforward

* make style

* add docs

* refactor layer names

* refactor modulation

* cleanup

* refactor norms

* refactor activations

* refactor single blocks attention

* refactor attention processor

* make style

* cleanup a bit

* refactor double transformer block attention

* update mochi attn proc

* use diffusers attention implementation in all modules; checkpoint for all values matching original

* remove helper functions in vae

* refactor upsample

* refactor causal conv

* refactor resnet

* refactor

* refactor

* refactor

* grad checkpointing

* autoencoder test

* fix scaling factor

* refactor clip

* refactor llama text encoding

* add coauthor

Co-Authored-By: "Gregory D. Hunkins" <greg@ollano.com>

* refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device

Note: The following line diverges from original behaviour. We create the grid on the device, whereas
original implementation creates it on CPU and then moves it to device. This results in numerical
differences in layerwise debugging outputs, but visually it is the same.

* use diffusers timesteps embedding; diff: 0.10205078125

* rename

* convert

* update

* add tests for transformer

* add pipeline tests; text encoder 2 is not optional

* fix attention implementation for torch

* add example

* update docs

* update docs

* apply suggestions from review

* refactor vae

* update

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Co-authored-by: hlky <hlky@hlky.ac>

* make fix-copies

* update

---------

Co-authored-by: "Gregory D. Hunkins" <greg@ollano.com>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-16 13:56:18 +05:30
Dhruv Nair 8957324363 Fix format issue in push_test yml (#10235)
update
2024-12-16 12:28:36 +05:30
Sayak Paul e68092a471 [docs] minor stuff to ltx video docs. (#10229)
minor stuff to ltx video docs.
2024-12-16 12:24:14 +05:30
Sayak Paul 3bf5400a64 Update sana.md with minor corrections (#10232) 2024-12-16 10:26:06 +05:30
Sayak Paul 02cbe972c3 [Tests] update always test pipelines list. (#10143)
update always test pipelines list.
2024-12-16 08:51:55 +05:30
Junsong Chen 5a196e3d46 [Sana] Add Sana, including SanaPipeline, SanaPAGPipeline, LinearAttentionProcessor, Flow-based DPM-sovler and so on. (#9982)
* first add a script for DC-AE;

* DC-AE init

* replace triton with custom implementation

* 1. rename file and remove un-used codes;

* no longer rely on omegaconf and dataclass

* replace custom activation with diffuers activation

* remove dc_ae attention in attention_processor.py

* iinherit from ModelMixin

* inherit from ConfigMixin

* dc-ae reduce to one file

* update downsample and upsample

* clean code

* support DecoderOutput

* remove get_same_padding and val2tuple

* remove autocast and some assert

* update ResBlock

* remove contents within super().__init__

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove opsequential

* update other blocks to support the removal of build_norm

* remove build encoder/decoder project in/out

* remove inheritance of RMSNorm2d from LayerNorm

* remove reset_parameters for RMSNorm2d

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove device and dtype in RMSNorm2d __init__

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove op_list & build_block

* remove build_stage_main

* change file name to autoencoder_dc

* move LiteMLA to attention.py

* align with other vae decode output;

* add DC-AE into init files;

* update

* make quality && make style;

* quick push before dgx disappears again

* update

* make style

* update

* update

* fix

* refactor

* refactor

* refactor

* update

* possibly change to nn.Linear

* refactor

* make fix-copies

* replace vae with ae

* replace get_block_from_block_type to get_block

* replace downsample_block_type from Conv to conv for consistency

* add scaling factors

* incorporate changes for all checkpoints

* make style

* move mla to attention processor file; split qkv conv to linears

* refactor

* add tests

* from original file loader

* add docs

* add standard autoencoder methods

* combine attention processor

* fix tests

* update

* minor fix

* minor fix

* minor fix & in/out shortcut rename

* minor fix

* make style

* fix paper link

* update docs

* update single file loading

* make style

* remove single file loading support; todo for DN6

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add abstract

* 1. add DCAE into diffusers;
2. make style and make quality;

* add DCAE_HF into diffusers;

* bug fixed;

* add SanaPipeline, SanaTransformer2D into diffusers;

* add sanaLinearAttnProcessor2_0;

* first update for SanaTransformer;

* first update for SanaPipeline;

* first success run SanaPipeline;

* model output finally match with original model with the same intput;

* code update;

* code update;

* add a flow dpm-solver scripts

* 🎉[important update]
1. Integrate flow-dpm-sovler into diffusers;
2. finally run successfully on both `FlowMatchEulerDiscreteScheduler` and `FlowDPMSolverMultistepScheduler`;

* 🎉🔧[important update & fix huge bugs!!]
1. add SanaPAGPipeline & several related Sana linear attention operators;
2. `SanaTransformer2DModel` not supports multi-resolution input;
2. fix the multi-scale HW bugs in SanaPipeline and SanaPAGPipeline;
3. fix the flow-dpm-solver set_timestep() init `model_output` and `lower_order_nums` bugs;

* remove prints;

* add convert sana official checkpoint to diffusers format Safetensor.

* Update src/diffusers/models/transformers/sana_transformer_2d.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/pag/pipeline_pag_sana.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update Sana for DC-AE's recent commit;

* make style && make quality

* Add StableDiffusion3PAGImg2Img Pipeline + Fix SD3 Unconditional PAG (#9932)

* fix progress bar updates in SD 1.5 PAG Img2Img pipeline

---------

Co-authored-by: Vinh H. Pham <phamvinh257@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make the vae can be None in `__init__` of `SanaPipeline`

* Update src/diffusers/models/transformers/sana_transformer_2d.py

Co-authored-by: hlky <hlky@hlky.ac>

* change the ae related code due to the latest update of DCAE branch;

* change the ae related code due to the latest update of DCAE branch;

* 1. change code based on AutoencoderDC;
2. fix the bug of new GLUMBConv;
3. run success;

* update for solving conversation.

* 1. fix bugs and run convert script success;
2. Downloading ckpt from hub automatically;

* make style && make quality;

* 1. remove un-unsed parameters in init;
2. code update;

* remove test file

* refactor; add docs; add tests; update conversion script

* make style

* make fix-copies

* refactor

* udpate pipelines

* pag tests and refactor

* remove sana pag conversion script

* handle weight casting in conversion script

* update conversion script

* add a processor

* 1. add bf16 pth file path;
2. add complex human instruct in pipeline;

* fix fast \tests

* change gemma-2-2b-it ckpt to a non-gated repo;

* fix the pth path bug in conversion script;

* change grad ckpt to original; make style

* fix the complex_human_instruct bug and typo;

* remove dpmsolver flow scheduler

* apply review suggestions

* change the `FlowMatchEulerDiscreteScheduler` to default `DPMSolverMultistepScheduler` with flow matching scheduler.

* fix the tokenizer.padding_side='right' bug;

* update docs

* make fix-copies

* fix imports

* fix docs

* add integration test

* update docs

* update examples

* fix convert_model_output in schedulers

* fix failing tests

---------

Co-authored-by: Junyu Chen <chenjydl2003@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: chenjy2003 <70215701+chenjy2003@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-16 02:16:56 +05:30
Aryan 22c4f079b1 Test error raised when loading normal and expanding loras together in Flux (#10188)
* add test for expanding lora and normal lora error

* Update tests/lora/test_lora_layers_flux.py

* fix things.

* Update src/diffusers/loaders/peft.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-15 21:46:21 +05:30
Junjie 96a9097445 Add offload option in flux-control training (#10225)
* Add offload option in flux-control training

* Update examples/flux-control/train_control_flux.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* modify help message

* fix format

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-15 20:49:17 +05:30
Juan Acevedo a5f35ee473 add reshape to fix use_memory_efficient_attention in flax (#7918)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-14 17:45:45 +01:00
hlky 63243406ba Use torch in get_2d_sincos_pos_embed and get_3d_sincos_pos_embed (#10156)
* Use torch in get_2d_sincos_pos_embed

* Use torch in get_3d_sincos_pos_embed

* get_1d_sincos_pos_embed_from_grid in LatteTransformer3DModel

* deprecate

* move deprecate, make private
2024-12-13 10:13:38 -10:00
Miguel Farinha 6bd30ba748 Allow image resolutions multiple of 8 instead of 64 in SVD pipeline (#6646)
allow resolutions not multiple of 64 in SVD

Co-authored-by: Miguel Farinha <mignha@CSL15958.local>
Co-authored-by: hlky <hlky@hlky.ac>
2024-12-13 16:17:15 +00:00
Linoy Tsaban cef0e3677e [RF inversion community pipeline] add eta_decay (#10199)
* add decay

* add decay

* style
2024-12-13 11:04:26 +02:00
skotapati ec9bfa9e14 Remove mps workaround for fp16 GELU, which is now supported natively (#10133)
* Remove mps workaround for fp16 GELU, which is now supported natively

---------

Co-authored-by: hlky <hlky@hlky.ac>
2024-12-12 16:05:59 -10:00
Bios bdbaea8f64 update StableDiffusion3Img2ImgPipeline.add image size validation (#10166)
* update StableDiffusion3Img2ImgPipeline.add image size validation

---------

Co-authored-by: hlky <hlky@hlky.ac>
2024-12-12 12:32:18 -10:00
hlky e8b65bffa2 refactor StableDiffusionXLControlNetUnion (#10200)
mode
2024-12-12 12:21:27 -10:00
hlky f2d348d904 Remove negative_* from SDXL callback (#10203)
* Remove `negative_*` from SDXL callback

* Change example and add XL version
2024-12-12 20:58:50 +00:00
Pauline Bailly-Masson c002724dd5 Ci update tpu (#10197)
* Update nightly_tests.yml for TPU CI

* Update push_tests.yml
2024-12-12 23:54:41 +05:30
Aryan 96c376a5ff [core] LTX Video (#10021)
* transformer

* make style & make fix-copies

* transformer

* add transformer tests

* 80% vae

* make style

* make fix-copies

* fix

* undo cogvideox changes

* update

* update

* match vae

* add docs

* t2v pipeline working; scheduler needs to be checked

* docs

* add pipeline test

* update

* update

* make fix-copies

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* copy t2v to i2v pipeline

* update

* apply review suggestions

* update

* make style

* remove framewise encoding/decoding

* pack/unpack latents

* image2video

* update

* make fix-copies

* update

* update

* rope scale fix

* debug layerwise code

* remove debug

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* propagate precision changes to i2v pipeline

* remove downcast

* address review comments

* fix comment

* address review comments

* [Single File] LTX support for loading original weights (#10135)

* from original file mixin for ltx

* undo config mapping fn changes

* update

* add single file to pipelines

* update docs

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

* rename classes based on ltx review

* point to original repository for inference

* make style

* resolve conflicts correctly

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-12-12 16:21:28 +05:30
181 changed files with 24141 additions and 937 deletions
+2
View File
@@ -357,6 +357,8 @@ jobs:
config:
- backend: "bitsandbytes"
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
runs-on:
group: aws-g6e-xlarge-plus
container:
+2 -1
View File
@@ -165,7 +165,8 @@ jobs:
group: gcp-ct5lp-hightpu-8t
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache defaults:
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
steps:
+1 -1
View File
@@ -46,7 +46,7 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
${CONDA_RUN} python -m uv pip install -e [quality,test]
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade
+24
View File
@@ -157,6 +157,10 @@
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao
title: torchao
title: Quantization Methods
- sections:
- local: optimization/fp16
@@ -234,6 +238,8 @@
title: Textual Inversion
- local: api/loaders/unet
title: UNet
- local: api/loaders/transformer_sd3
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
title: Loaders
@@ -270,10 +276,14 @@
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
@@ -282,6 +292,8 @@
title: PriorTransformer
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
@@ -312,6 +324,10 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
@@ -386,8 +402,12 @@
title: DiT
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/pix2pix
@@ -408,6 +428,8 @@
title: Latte
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTX
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
@@ -428,6 +450,8 @@
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
+113 -18
View File
@@ -15,40 +15,135 @@ specific language governing permissions and limitations under the License.
An attention processor is a class for applying different types of attention mechanisms.
## AttnProcessor
[[autodoc]] models.attention_processor.AttnProcessor
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
## AttnAddedKVProcessor
[[autodoc]] models.attention_processor.AttnAddedKVProcessor
## AttnAddedKVProcessor2_0
[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0
## CrossFrameAttnProcessor
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
[[autodoc]] models.attention_processor.AttnProcessorNPU
## CustomDiffusionAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
## CustomDiffusionAttnProcessor2_0
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
## CustomDiffusionXFormersAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
## Allegro
[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0
## AuraFlow
[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0
## CogVideoX
[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0
## CrossFrameAttnProcessor
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
## Custom Diffusion
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
## Flux
[[autodoc]] models.attention_processor.FluxAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedFluxAttnProcessor2_0
[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0
## Hunyuan
[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0
## IdentitySelfAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0
## IP-Adapter
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
## JointAttnProcessor2_0
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0
## LoRA
[[autodoc]] models.attention_processor.LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
## Lumina-T2X
[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0
## Mochi
[[autodoc]] models.attention_processor.MochiAttnProcessor2_0
[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0
## Sana
[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0
[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0
[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0
## Stable Audio
[[autodoc]] models.attention_processor.StableAudioAttnProcessor2_0
## SlicedAttnProcessor
[[autodoc]] models.attention_processor.SlicedAttnProcessor
## SlicedAttnAddedKVProcessor
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
## XFormersAttnProcessor
[[autodoc]] models.attention_processor.XFormersAttnProcessor
## AttnProcessorNPU
[[autodoc]] models.attention_processor.AttnProcessorNPU
[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor
## XLAFlashAttnProcessor2_0
[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0
+6
View File
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
## SD3IPAdapterMixin
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
- all
- is_ip_adapter_active
## IPAdapterMaskProcessor
[[autodoc]] image_processor.IPAdapterMaskProcessor
+15
View File
@@ -17,6 +17,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
@@ -38,6 +41,18 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
## FluxLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
## Mochi1LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
@@ -0,0 +1,29 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SD3Transformer2D
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
<Tip>
To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
</Tip>
## SD3Transformer2DLoadersMixin
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
- all
- _load_ip_adapter_weights
@@ -29,6 +29,8 @@ The following DCAE models are released and supported in Diffusers.
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)
This model was contributed by [lawrence-cj](https://github.com/lawrence-cj).
Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].
```python
@@ -0,0 +1,32 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanVideo
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanVideo
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
```
## AutoencoderKLHunyuanVideo
[[autodoc]] AutoencoderKLHunyuanVideo
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,37 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTXVideo
The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTXVideo
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTXVideo
[[autodoc]] AutoencoderKLLTXVideo
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HunyuanVideoTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanVideoTransformer3DModel
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
```
## HunyuanVideoTransformer3DModel
[[autodoc]] HunyuanVideoTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTXVideoTransformer3DModel
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import LTXVideoTransformer3DModel
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTXVideoTransformer3DModel
[[autodoc]] LTXVideoTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,34 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# SanaTransformer2DModel
A Diffusion Transformer model for 2D data from [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) was introduced from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
The abstract from the paper is:
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
The model can be loaded with the following code snippet.
```python
from diffusers import SanaTransformer2DModel
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
```
## SanaTransformer2DModel
[[autodoc]] SanaTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,89 @@
<!--Copyright 2024 The HuggingFace Team, The Black Forest Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FluxControlInpaint
FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
| Control type | Developer | Link |
| -------- | ---------- | ---- |
| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
<Tip>
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
</Tip>
```python
import torch
from diffusers import FluxControlInpaintPipeline
from diffusers.models.transformers import FluxTransformer2DModel
from transformers import T5EncoderModel
from diffusers.utils import load_image, make_image_grid
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
from PIL import Image
import numpy as np
pipe = FluxControlInpaintPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Depth-dev",
torch_dtype=torch.bfloat16,
)
# use following lines if you have GPU constraints
# ---------------------------------------------------------------
transformer = FluxTransformer2DModel.from_pretrained(
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder_2 = T5EncoderModel.from_pretrained(
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()
# ---------------------------------------------------------------
pipe.to("cuda")
prompt = "a blue robot singing opera with human-like expressions"
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
head_mask = np.zeros_like(image)
head_mask[65:580,300:642] = 255
mask_image = Image.fromarray(head_mask)
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(image)[0].convert("RGB")
output = pipe(
prompt=prompt,
image=image,
control_image=control_image,
mask_image=mask_image,
num_inference_steps=30,
strength=0.9,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
```
## FluxControlInpaintPipeline
[[autodoc]] FluxControlInpaintPipeline
- all
- __call__
## FluxPipelineOutput
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
+37
View File
@@ -268,6 +268,43 @@ images = pipe(
images[0].save("flux-redux.png")
```
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
```py
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = control_pipe(
prompt=prompt,
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")
```
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
@@ -0,0 +1,43 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HunyuanVideo
[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
Recommendations for inference:
- Both text encoders should be in `torch.float16`.
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
## HunyuanVideoPipeline
[[autodoc]] HunyuanVideoPipeline
- all
- __call__
## HunyuanVideoPipelineOutput
[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
+118
View File
@@ -0,0 +1,118 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(
single_file_url, torch_dtype=torch.bfloat16
)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained(
"Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
# ... inference code ...
```
Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].
```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained(
"Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
tokenizer = T5Tokenizer.from_pretrained(
"Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16
)
pipe = LTXImageToVideoPipeline.from_single_file(
single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16
)
```
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
```py
import torch
from diffusers.utils import export_to_video
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
)
transformer = LTXVideoTransformer3DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = LTXPipeline.from_pretrained(
"Lightricks/LTX-Video",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=704,
height=480,
num_frames=161,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
```
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
## LTXPipeline
[[autodoc]] LTXPipeline
- all
- __call__
## LTXImageToVideoPipeline
[[autodoc]] LTXImageToVideoPipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
+196 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
-->
# Mochi
# Mochi 1 Preview
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
## Generating videos with Mochi-1 Preview
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
# Enable memory savings
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
frames = pipe(prompt, num_frames=85).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)
```
## Using a lower precision variant to save memory
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
```python
import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
# Enable memory savings
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
frames = pipe(prompt, num_frames=85).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)
```
## Reproducing the results from the Genmo Mochi repo
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
<Tip>
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
</Tip>
<Tip>
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
</Tip>
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
with torch.no_grad():
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
pipe.encode_prompt(prompt=prompt)
)
with torch.autocast("cuda", torch.bfloat16):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
frames = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
guidance_scale=4.5,
num_inference_steps=64,
height=480,
width=848,
num_frames=163,
generator=torch.Generator("cuda").manual_seed(0),
output_type="latent",
return_dict=False,
)[0]
video_processor = VideoProcessor(vae_scale_factor=8)
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
latents_std = (
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
else:
frames = frames / pipe.vae.config.scaling_factor
with torch.no_grad():
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
video = video_processor.postprocess_video(video)[0]
export_to_video(video, "mochi.mp4", fps=30)
```
## Running inference with multiple GPUs
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
model_id = "genmo/mochi-1-preview"
transformer = MochiTransformer3DModel.from_pretrained(
model_id,
subfolder="transformer",
device_map="auto",
max_memory={0: "24GB", 1: "24GB"}
)
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
frames = pipe(
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
negative_prompt="",
height=480,
width=848,
num_frames=85,
num_inference_steps=50,
guidance_scale=4.5,
num_videos_per_prompt=1,
generator=torch.Generator(device="cuda").manual_seed(0),
max_sequence_length=256,
output_type="pil",
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
## Using single file loading with the Mochi Transformer
You can use `from_single_file` to load the Mochi transformer in its original format.
<Tip>
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
</Tip>
```python
import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
model_id = "genmo/mochi-1-preview"
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
frames = pipe(
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
negative_prompt="",
height=480,
width=848,
num_frames=85,
num_inference_steps=50,
guidance_scale=4.5,
num_videos_per_prompt=1,
generator=torch.Generator(device="cuda").manual_seed(0),
max_sequence_length=256,
output_type="pil",
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
## MochiPipeline
[[autodoc]] MochiPipeline
+67
View File
@@ -0,0 +1,67 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaPipeline
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
The abstract from the paper is:
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).
Available models:
| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px_diffusers) | `torch.float16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information.
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
<Tip>
Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
</Tip>
## SanaPipeline
[[autodoc]] SanaPipeline
- all
- __call__
## SanaPAGPipeline
[[autodoc]] SanaPAGPipeline
- all
- __call__
## SanaPipelineOutput
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
## Image Prompting with IP-Adapters
An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
```python
import torch
from PIL import Image
from diffusers import StableDiffusion3Pipeline
from transformers import SiglipVisionModel, SiglipImageProcessor
image_encoder_id = "google/siglip-so400m-patch14-384"
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
feature_extractor = SiglipImageProcessor.from_pretrained(
image_encoder_id,
torch_dtype=torch.float16
)
image_encoder = SiglipVisionModel.from_pretrained(
image_encoder_id,
torch_dtype=torch.float16
).to( "cuda")
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-large",
torch_dtype=torch.float16,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
).to("cuda")
pipe.load_ip_adapter(ip_adapter_id)
pipe.set_ip_adapter_scale(0.6)
ref_img = Image.open("image.jpg").convert('RGB')
image = pipe(
width=1024,
height=1024,
prompt="a cat",
negative_prompt="lowres, low quality, worst quality",
num_inference_steps=24,
guidance_scale=5.0,
ip_adapter_image=ref_img
).images[0]
image.save("result.jpg")
```
<div class="justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/>
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption>
</div>
<Tip>
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
</Tip>
## Memory Optimisations for SD3
SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
### Running Inference with Model Offloading
+7
View File
@@ -28,6 +28,13 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
[[autodoc]] BitsAndBytesConfig
## GGUFQuantizationConfig
[[autodoc]] GGUFQuantizationConfig
## TorchAoConfig
[[autodoc]] TorchAoConfig
## DiffusersQuantizer
[[autodoc]] quantizers.base.DiffusersQuantizer
+69
View File
@@ -0,0 +1,69 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# GGUF
The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
Before starting please install gguf in your environment
```shell
pip install -U gguf
```
Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Supported Quantization Types
- BF16
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q2_K
- Q3_K
- Q4_K
- Q5_K
- Q6_K
+7 -2
View File
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
<Tip>
Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
</Tip>
@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+92
View File
@@ -0,0 +1,92 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# torchao
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
```bash
pip install -U torch torchao
```
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
The example below only quantizes the weights to int8.
```python
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=dtype,
)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```
TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
The `TorchAoConfig` class accepts three parameters:
- `quant_type`: A string value mentioning one of the quantization types below.
- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
## Supported quantization types
torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
The quantization methods supported are as follows:
| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
@@ -56,7 +56,7 @@ image
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
```python
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte
You can also merge different adapter checkpoints for inference to blend their styles together.
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
```python
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
> [!TIP]
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
```python
pipe.set_adapters("toy")
@@ -127,7 +127,7 @@ image = pipe(
image
```
Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
```python
pipe.disable_lora()
@@ -140,7 +140,8 @@ image
![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)
### Customize adapters strength
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
```python
@@ -195,7 +196,7 @@ image
![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)
## Manage active adapters
## Manage adapters
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
@@ -212,3 +213,11 @@ list_adapters_component_wise = pipe.get_list_adapters()
list_adapters_component_wise
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
```
The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
```py
pipe.delete_adapters("toy")
pipe.get_active_adapters()
["pixel"]
```
+152 -30
View File
@@ -241,27 +241,15 @@ from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
from typing import Any, Dict, Optional
from typing import Any, Dict, Tuple, Union
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")
pipeline.safety_checker = None
pipeline.requires_safety_checker = False
class SDPromptScheduleCallback(PipelineCallback):
class SDPromptSchedulingCallback(PipelineCallback):
@register_to_config
def __init__(
self,
prompt: str,
negative_prompt: Optional[str] = None,
num_images_per_prompt: int = 1,
cutoff_step_ratio=1.0,
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
cutoff_step_ratio=None,
cutoff_step_index=None,
):
super().__init__(
@@ -275,6 +263,10 @@ class SDPromptScheduleCallback(PipelineCallback):
) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
if isinstance(self.config.encoded_prompt, tuple):
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
else:
prompt_embeds = self.config.encoded_prompt
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
@@ -284,34 +276,164 @@ class SDPromptScheduleCallback(PipelineCallback):
)
if step_index == cutoff_step:
prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
prompt=self.config.prompt,
negative_prompt=self.config.negative_prompt,
device=pipeline._execution_device,
num_images_per_prompt=self.config.num_images_per_prompt,
do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
)
if pipeline.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
return callback_kwargs
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")
pipeline.safety_checker = None
pipeline.requires_safety_checker = False
callback = MultiPipelineCallbacks(
[
SDPromptScheduleCallback(
prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
negative_prompt="Deformed, ugly, bad anatomy",
cutoff_step_ratio=0.25,
)
SDPromptSchedulingCallback(
encoded_prompt=pipeline.encode_prompt(
prompt=f"prompt {index}",
negative_prompt=f"negative prompt {index}",
device=pipeline._execution_device,
num_images_per_prompt=1,
# pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
do_classifier_free_guidance=True,
),
cutoff_step_index=index,
) for index in range(1, 20)
]
)
image = pipeline(
prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
negative_prompt="Deformed, ugly, bad anatomy",
prompt="prompt"
negative_prompt="negative prompt",
callback_on_step_end=callback,
callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
torch.cuda.empty_cache()
image.save('image.png')
```
```python
from diffusers import StableDiffusionXLPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
from typing import Any, Dict, Tuple, Union
class SDXLPromptSchedulingCallback(PipelineCallback):
@register_to_config
def __init__(
self,
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
cutoff_step_ratio=None,
cutoff_step_index=None,
):
super().__init__(
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
)
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
def callback_fn(
self, pipeline, step_index, timestep, callback_kwargs
) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
if isinstance(self.config.encoded_prompt, tuple):
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
else:
prompt_embeds = self.config.encoded_prompt
if isinstance(self.config.add_text_embeds, tuple):
add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
else:
add_text_embeds = self.config.add_text_embeds
if isinstance(self.config.add_time_ids, tuple):
add_time_ids, negative_add_time_ids = self.config.add_time_ids
else:
add_time_ids = self.config.add_time_ids
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index
if cutoff_step_index is not None
else int(pipeline.num_timesteps * cutoff_step_ratio)
)
if step_index == cutoff_step:
if pipeline.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
return callback_kwargs
pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")
callbacks = []
for index in range(1, 20):
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
prompt=f"prompt {index}",
negative_prompt=f"prompt {index}",
device=pipeline._execution_device,
num_images_per_prompt=1,
# pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
do_classifier_free_guidance=True,
)
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
add_time_ids = pipeline._get_add_time_ids(
(1024, 1024),
(0, 0),
(1024, 1024),
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
negative_add_time_ids = pipeline._get_add_time_ids(
(1024, 1024),
(0, 0),
(1024, 1024),
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
callbacks.append(
SDXLPromptSchedulingCallback(
encoded_prompt=(prompt_embeds, negative_prompt_embeds),
add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
add_time_ids=(add_time_ids, negative_add_time_ids),
cutoff_step_index=index,
)
)
callback = MultiPipelineCallbacks(callbacks)
image = pipeline(
prompt="prompt",
negative_prompt="negative prompt",
callback_on_step_end=callback,
callback_on_step_end_tensor_inputs=[
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
],
).images[0]
```
@@ -648,6 +648,8 @@ class RFInversionFluxPipeline(
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 1.0,
decay_eta: Optional[bool] = False,
eta_decay_power: Optional[float] = 1.0,
strength: float = 1.0,
start_timestep: float = 0,
stop_timestep: float = 0.25,
@@ -880,12 +882,9 @@ class RFInversionFluxPipeline(
v_t = -noise_pred
v_t_cond = (y_0 - latents) / (1 - t_i)
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
if start_timestep <= i < stop_timestep:
# controlled vector field
v_hat_t = v_t + eta * (v_t_cond - v_t)
else:
v_hat_t = v_t
if decay_eta:
eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop
v_hat_t = v_t + eta_t * (v_t_cond - v_t)
# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
@@ -1008,6 +1008,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)
style = torch.tensor([0], device=device)
@@ -129,7 +129,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
self.power = int(rp_args["power"]) if "power" in rp_args else 1
prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)
if use_base:
+127
View File
@@ -0,0 +1,127 @@
# DreamBooth training example for SANA
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629).
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sana.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
accelerate launch train_dreambooth_lora_sana.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
huggingface-cli login
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
## Notes
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
* `--complex_human_instruction`: Instructions for complex human attention as shown in [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
We provide several options for optimizing memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
@@ -0,0 +1,8 @@
accelerate>=1.0.0
torchvision
transformers>=4.47.0
ftfy
tensorboard
Jinja2
peft>=0.14.0
sentencepiece
File diff suppressed because it is too large Load Diff
+2
View File
@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--offload \
--seed="0" \
--push_to_hub
```
@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
--validation_steps=200 \
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
--offload \
--seed="0" \
--push_to_hub
```
+10 -3
View File
@@ -541,6 +541,11 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--offload",
action="store_true",
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -999,8 +1004,9 @@ def main(args):
control_latents = encode_images(
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
)
# offload vae to CPU.
vae.cpu()
if args.offload:
# offload vae to CPU.
vae.cpu()
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -1064,7 +1070,8 @@ def main(args):
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
prompt_embeds.zero_()
pooled_prompt_embeds.zero_()
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Predict.
model_pred = flux_transformer(
@@ -573,6 +573,11 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--offload",
action="store_true",
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1140,8 +1145,10 @@ def main(args):
control_latents = encode_images(
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
)
# offload vae to CPU.
vae.cpu()
if args.offload:
# offload vae to CPU.
vae.cpu()
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -1205,7 +1212,8 @@ def main(args):
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
prompt_embeds.zero_()
pooled_prompt_embeds.zero_()
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
# Predict.
model_pred = flux_transformer(
@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext
import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
if is_transformers_available():
from transformers import CLIPVisionModelWithProjection
vision = True
else:
vision = False
"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors"
--output_path "flux-ip-adapter-hf/"
"""
CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
args = parser.parse_args()
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
converted_state_dict = {}
# image_proj
## norm
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
## proj
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"ip_adapter.{i}."
# to_k_ip
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
)
# to_v_ip
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
)
return converted_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
num_layers = 19
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
print("Saving Flux IP-Adapter in Diffusers format.")
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
if vision:
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
model.save_pretrained(f"{args.output_path}/image_encoder")
if __name__ == "__main__":
main(args)
@@ -0,0 +1,257 @@
import argparse
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key
if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)
def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072
if "linear1.weight" in key:
linear1_weight = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
state_dict[f"{new_key}.attn.to_q.weight"] = q
state_dict[f"{new_key}.attn.to_k.weight"] = k
state_dict[f"{new_key}.attn.to_v.weight"] = v
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
elif "linear1.bias" in key:
linear1_bias = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
VAE_KEYS_RENAME_DICT = {}
VAE_SPECIAL_KEYS_REMAP = {}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def convert_transformer(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
with init_empty_weights():
transformer = HunyuanVideoTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
with init_empty_weights():
vae = AutoencoderKLHunyuanVideo()
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
assert args.text_encoder_2_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
pipe = HunyuanVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+209
View File
@@ -0,0 +1,209 @@
import argparse
from typing import Any, Dict
import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
TOKENIZER_MAX_LENGTH = 128
TRANSFORMER_KEYS_RENAME_DICT = {
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VAE_KEYS_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""
original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
def convert_vae(ckpt_path: str, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
parser.add_argument(
"--typecast_text_encoder",
action="store_true",
default=False,
help="Whether or not to apply fp16/bf16 precision to text_encoder",
)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
if args.transformer_ckpt_path is not None:
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
if args.typecast_text_encoder:
text_encoder = text_encoder.to(dtype=dtype)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
+308
View File
@@ -0,0 +1,308 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
import os
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderDC,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaTransformer2DModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
"Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth",
"Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth",
]
# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py
def main(args):
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
snapshot_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
cache_dir=cache_dir_path,
repo_type="model",
)
file_path = hf_hub_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
cache_dir=cache_dir_path,
repo_type="model",
)
else:
file_path = args.orig_ckpt_path
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
all_state_dict = torch.load(file_path, weights_only=True)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
flow_shift = 3.0
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
else:
raise ValueError(f"{args.model_type} is not supported.")
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Linear Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.point_conv.conv.weight"
)
# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
# Transformer
with CTX():
transformer = SanaTransformer2DModel(
in_channels=32,
out_channels=32,
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
mlp_ratio=2.5,
attention_bias=False,
sample_size=args.image_size // 32,
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_state_dict)
else:
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
transformer = transformer.to(weight_dtype)
if not args.save_full_pipeline:
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole SanaPipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
)
else:
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[512, 1024, 2048],
required=False,
help="Image size of pretrained model, 512, 1024 or 2048.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
)
parser.add_argument(
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
model_kwargs = {
"SanaMS_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
"SanaMS_600M_P1_D28": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)
+25 -2
View File
@@ -31,7 +31,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -84,6 +84,8 @@ else:
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
@@ -101,9 +103,11 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
@@ -112,6 +116,7 @@ else:
"MultiControlNetModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaTransformer2DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -272,6 +277,7 @@ else:
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
@@ -284,6 +290,7 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -317,6 +324,8 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXImageToVideoPipeline",
"LTXPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
@@ -328,6 +337,8 @@ else:
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"ReduxImageEncoder",
"SanaPAGPipeline",
"SanaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -341,6 +352,7 @@ else:
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
@@ -558,7 +570,7 @@ else:
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
try:
if not is_onnx_available():
@@ -582,6 +594,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
@@ -599,9 +613,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
@@ -610,6 +626,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiControlNetModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -749,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
@@ -761,6 +779,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
@@ -794,6 +813,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXImageToVideoPipeline,
LTXPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
@@ -805,6 +826,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
ReduxImageEncoder,
SanaPAGPipeline,
SanaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
+20 -3
View File
@@ -55,7 +55,8 @@ _import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
@@ -65,13 +66,20 @@ if is_torch_available():
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
"Mochi1LoraLoaderMixin",
"HunyuanVideoLoraLoaderMixin",
"SanaLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -79,17 +87,26 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
FluxLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Mochi1LoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
+529 -8
View File
@@ -33,15 +33,20 @@ from .unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available():
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
JointAttnProcessor2_0,
SD3IPAdapterJointAttnProcessor2_0,
)
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
)
logger = logging.get_logger(__name__)
@@ -348,3 +353,519 @@ class IPAdapterMixin:
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
Can be either:
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
for key in f.keys():
if any(key.startswith(prefix) for prefix in image_proj_keys):
diffusers_name = ".".join(key.split(".")[1:])
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
diffusers_name = (
".".join(key.split(".")[1:])
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
.replace("processor.", "")
)
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_pretrained_model_name_or_path is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
image_encoder = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_pretrained_model_name_or_path,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.to(self.device, dtype=image_encoder_dtype)
.eval()
)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
number of IP adapters and each must match the number of blocks.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
def LinearStrengthModel(start, finish, size):
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
transformer = self.transformer
if not isinstance(scale, list):
scale = [[scale] * transformer.config.num_layers]
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
if len(scale) != transformer.config.num_layers:
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
scale = [scale]
scale_configs = scale
key_id = 0
for attn_name, attn_processor in transformer.attn_processors.items():
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
attn_processor.scale[i] = scale_config[key_id]
key_id += 1
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove hidden encoder
self.transformer.encoder_hid_proj = None
self.transformer.config.encoder_hid_dim_type = None
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
class SD3IPAdapterMixin:
"""Mixin for handling StableDiffusion 3 IP Adapters."""
@property
def is_ip_adapter_active(self) -> bool:
"""Checks if IP-Adapter is loaded and scale > 0.
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
the image context is irrelevant.
Returns:
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
"""
scales = [
attn_proc.scale
for attn_proc in self.transformer.attn_processors.values()
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
]
return len(scales) > 0 and any(scale > 0 for scale in scales)
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str = "ip-adapter.safetensors",
subfolder: Optional[str] = None,
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
) -> None:
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
weight_name (`str`, defaults to "ip-adapter.safetensors"):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
subfolder (`str`, *optional*):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
`image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# Load the main state dict first
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
# Commons args for loading image encoder and image processor
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage,
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
self.register_modules(
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# Load IP-Adapter into transformer
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: float) -> None:
"""
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
the model to produce more diverse images, but they may not be as aligned with the image prompt.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.set_ip_adapter_scale(0.6)
>>> ...
```
Args:
scale (float):
IP-Adapter scale to be set.
"""
for attn_processor in self.transformer.attn_processors.values():
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
attn_processor.scale = scale
def unload_ip_adapter(self) -> None:
"""
Unloads the IP Adapter weights.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# Remove image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=None)
# Remove feature extractor
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=None)
# Remove image projection
self.transformer.image_proj = None
# Restore original attention processors layers
attn_procs = {
name: (
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
)
for name, value in self.transformer.attn_processors.items()
}
self.transformer.set_attn_processor(attn_procs)
@@ -643,7 +643,11 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
old_state_dict,
new_state_dict,
old_key,
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
[
f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
],
)
if "down" in old_key:
File diff suppressed because it is too large Load Diff
+20 -2
View File
@@ -53,6 +53,9 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
}
@@ -205,6 +208,7 @@ class PeftAdapterMixin:
weights.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -316,8 +320,22 @@ class PeftAdapterMixin:
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except RuntimeError as e:
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
raise
warn_msg = ""
if incompatible_keys is not None:
+59 -2
View File
@@ -17,8 +17,10 @@ import re
from contextlib import nullcontext
from typing import Optional
import torch
from huggingface_hub.utils import validate_hf_hub_args
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
@@ -28,6 +30,9 @@ from .single_file_utils import (
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
@@ -83,7 +88,19 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTXVideo": {
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
"MochiTransformer3DModel": {
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -204,6 +221,8 @@ class FromOriginalModelMixin:
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
@@ -217,6 +236,12 @@ class FromOriginalModelMixin:
local_files_only=local_files_only,
revision=revision,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
else:
hf_quantizer = None
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
@@ -299,8 +324,36 @@ class FromOriginalModelMixin:
with ctx():
model = cls.from_config(diffusers_model_config)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
keep_in_fp32_modules = [keep_in_fp32_modules]
else:
keep_in_fp32_modules = []
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
device_map=None,
state_dict=diffusers_format_checkpoint,
keep_in_fp32_modules=keep_in_fp32_modules,
)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
param_device = torch.device(device) if device else torch.device("cpu")
unexpected_keys = load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -314,7 +367,11 @@ class FromOriginalModelMixin:
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if torch_dtype is not None:
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype)
model.eval()
+236 -7
View File
@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
"sd3": [
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
],
"sd35_large": [
"joint_blocks.37.x_block.mlp.fc1.weight",
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
],
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -92,8 +98,16 @@ CHECKPOINT_KEY_NAMES = {
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"ltx-video": [
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.per_channel_statistics.mean-of-means",
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -139,11 +153,15 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
}
# Use to configure model sample size when original config is provided
@@ -535,13 +553,20 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "stable_cascade_stage_b"
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
):
if "model.diffusion_model.pos_embed" in checkpoint:
key = "model.diffusion_model.pos_embed"
else:
key = "pos_embed"
if checkpoint[key].shape[1] == 36864:
model_type = "sd3"
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
elif checkpoint[key].shape[1] == 147456:
model_type = "sd35_medium"
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
model_type = "sd35_large"
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
@@ -567,10 +592,19 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
model_type = "flux-dev"
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint["img_in.weight"].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
else:
model_type = "flux-schnell"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
decoder_key = "decoder.project_in.main.conv.weight"
@@ -587,6 +621,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "autoencoder-dc-f128c512"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
model_type = "mochi-1-preview"
else:
model_type = "v1"
@@ -1727,6 +1764,12 @@ def swap_scale_shift(weight, dim):
return new_weight
def swap_proj_gate(weight):
proj, gate = weight.chunk(2, dim=0)
new_weight = torch.cat([gate, proj], dim=0)
return new_weight
def get_attn2_layers(state_dict):
attn2_layers = []
for key in state_dict.keys():
@@ -2223,6 +2266,94 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
return converted_state_dict
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
TRANSFORMER_KEYS_RENAME_DICT = {
"model.diffusion_model.": "",
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
def remove_keys_(key: str, state_dict):
state_dict.pop(key)
VAE_KEYS_RENAME_DICT = {
# common
"vae.": "",
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
@@ -2293,3 +2424,101 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
new_state_dict = {}
# Comfy checkpoints add this prefix
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
# Convert time_embed
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
# Convert transformer blocks
num_layers = 48
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"blocks.{i}."
# norm1
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
if i < num_layers - 1:
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
else:
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
# Visual attention
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
# Context attention
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
old_prefix + "attn.q_norm_y.weight"
)
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
old_prefix + "attn.k_norm_y.weight"
)
if i < num_layers - 1:
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
old_prefix + "attn.proj_y.weight"
)
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
)
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
if i < num_layers - 1:
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
)
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
# Output layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
return new_state_dict
+179
View File
@@ -0,0 +1,179 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
if is_accelerate_available():
pass
logger = logging.get_logger(__name__)
class FluxTransformer2DLoadersMixin:
"""
Load layers into a [`FluxTransformer2DModel`].
"""
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 0
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
num_image_text_embed = 4
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
num_image_text_embed = 16
# IP-Adapter
num_image_text_embeds += [num_image_text_embed]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
dtype=self.dtype,
device=self.device,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device = self.device
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
key_id += 1
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
class SD3Transformer2DLoadersMixin:
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
Args:
state_dict (`Dict`):
State dict with keys "ip_adapter", which contains parameters for attention processors, and
"image_proj", which contains parameters for image projection net.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# IP-Adapter cross attention parameters
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
# Dict where key is transformer layer index, value is attention processor's state dict
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
for key, weights in state_dict["ip_adapter"].items():
idx, name = key.split(".", maxsplit=1)
layer_state_dict[int(idx)][name] = weights
# Create IP-Adapter attention processor
attn_procs = {}
for idx, name in enumerate(self.attn_processors.keys()):
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size,
ip_hidden_states_dim=ip_hidden_states_dim,
head_dim=self.config.attention_head_dim,
timesteps_emb_dim=timesteps_emb_dim,
).to(self.device, dtype=self.dtype)
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
load_model_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)
self.set_attn_processor(attn_procs)
# Image projetion parameters
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
num_queries = state_dict["image_proj"]["latents"].shape[1]
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
# Image projection
self.image_proj = IPAdapterTimeImageProjection(
embed_dim=embed_dim,
output_dim=output_dim,
hidden_dim=hidden_dim,
heads=heads,
num_queries=num_queries,
timestep_in_dim=timestep_in_dim,
).to(device=self.device, dtype=self.dtype)
if not low_cpu_mem_usage:
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
else:
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+10
View File
@@ -31,6 +31,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -59,12 +61,15 @@ if is_torch_available():
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -94,6 +99,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
@@ -126,11 +133,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DualTransformer2DModel,
FluxTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
SD3Transformer2DModel,
StableAudioDiTModel,
T5FilmDecoder,
+21 -9
View File
@@ -18,7 +18,7 @@ import torch.nn.functional as F
from torch import nn
from ..utils import deprecate
from ..utils.import_utils import is_torch_npu_available
from ..utils.import_utils import is_torch_npu_available, is_torch_version
if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
# fp16 gelu not supported on mps before torch 2.0
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
return F.gelu(gate, approximate=self.approximate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
# fp16 gelu not supported on mps before torch 2.0
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
return F.gelu(gate)
def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None:
@@ -164,3 +164,15 @@ class ApproximateGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
class LinearActivation(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.activation = get_activation(activation)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
return self.activation(hidden_states)
+13 -4
View File
@@ -19,7 +19,7 @@ from torch import nn
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
self._chunk_dim = dim
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
@@ -1222,6 +1229,8 @@ class FeedForward(nn.Module):
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
elif activation_fn == "linear-silu":
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
self.net = nn.ModuleList([])
# project in
+1 -1
View File
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
)
hidden_states = hidden_states.transpose(1, 0, 2)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
else:
# compute attentions
if self.split_head_dim:
+688 -92
View File
@@ -199,12 +199,16 @@ class Attention(nn.Module):
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "layer_norm_across_heads":
# Lumina applys qk norm across all heads
# Lumina applies qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "l2":
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
@@ -250,14 +254,22 @@ class Attention(nn.Module):
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
else:
self.add_q_proj = None
self.add_k_proj = None
self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
else:
self.to_add_out = None
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
@@ -563,7 +575,7 @@ class Attention(nn.Module):
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
@@ -778,7 +790,11 @@ class Attention(nn.Module):
self.to_kv.bias.copy_(concatenated_bias)
# handle added projections for SD3 and others.
if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
@@ -890,6 +906,177 @@ class SanaMultiscaleLinearAttention(nn.Module):
return self.processor(self, hidden_states)
class MochiAttention(nn.Module):
def __init__(
self,
query_dim: int,
added_kv_proj_dim: int,
processor: "MochiAttnProcessor2_0",
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_proj_bias: bool = True,
out_dim: Optional[int] = None,
out_context_dim: Optional[int] = None,
out_bias: bool = True,
context_pre_only: bool = False,
eps: float = 1e-5,
):
super().__init__()
from .normalization import MochiRMSNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim else query_dim
self.context_pre_only = context_pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.norm_q = MochiRMSNorm(dim_head, eps, True)
self.norm_k = MochiRMSNorm(dim_head, eps, True)
self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
if not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**kwargs,
)
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: "MochiAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
total_length = sequence_length + encoder_sequence_length
batch_size, heads, _, dim = query.shape
attn_outputs = []
for idx in range(batch_size):
mask = attention_mask[idx][None, :]
valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
attn_output = F.scaled_dot_product_attention(
valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
)
valid_sequence_length = attn_output.size(2)
attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
attn_outputs.append(attn_output)
hidden_states = torch.cat(attn_outputs, dim=0)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
@@ -2466,6 +2653,149 @@ class FusedFluxAttnProcessor2_0_NPU:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_attn_output = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_attn_output = scale * ip_attn_output
ip_attn_output = ip_attn_output.to(ip_query.dtype)
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3852,94 +4182,6 @@ class LuminaAttnProcessor2_0:
return hidden_states
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -5144,6 +5386,177 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
return hidden_states
class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""
Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
additional image-based information and timestep embeddings.
Args:
hidden_size (`int`):
The number of hidden channels.
ip_hidden_states_dim (`int`):
The image feature dimension.
head_dim (`int`):
The number of head channels.
timesteps_emb_dim (`int`, defaults to 1280):
The number of input channels for timestep embedding.
scale (`float`, defaults to 0.5):
IP-Adapter scale.
"""
def __init__(
self,
hidden_size: int,
ip_hidden_states_dim: int,
head_dim: int,
timesteps_emb_dim: int = 1280,
scale: float = 0.5,
):
super().__init__()
# To prevent circular import
from .normalization import AdaLayerNorm, RMSNorm
self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
self.norm_q = RMSNorm(head_dim, 1e-6)
self.norm_k = RMSNorm(head_dim, 1e-6)
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
self.scale = scale
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
ip_hidden_states: torch.FloatTensor = None,
temb: torch.FloatTensor = None,
) -> torch.FloatTensor:
"""
Perform the attention computation, integrating image features (if provided) and timestep embeddings.
If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
Args:
attn (`Attention`):
Attention instance.
hidden_states (`torch.FloatTensor`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
The encoder hidden states.
attention_mask (`torch.FloatTensor`, *optional*):
Attention mask.
ip_hidden_states (`torch.FloatTensor`, *optional*):
Image embeddings.
temb (`torch.FloatTensor`, *optional*):
Timestep embeddings.
Returns:
`torch.FloatTensor`: Output hidden states.
"""
residual = hidden_states
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
img_query = query
img_key = key
img_value = value
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP Adapter
if self.scale != 0 and ip_hidden_states is not None:
# Norm image features
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
# To k and v
ip_key = self.to_k_ip(norm_ip_hidden_states)
ip_value = self.to_v_ip(norm_ip_hidden_states)
# Reshape
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Norm
query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
ip_key = self.norm_ip_k(ip_key)
# cat img
key = torch.cat([img_key, ip_key], dim=2)
value = torch.cat([img_value, ip_value], dim=2)
ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + ip_hidden_states * self.scale
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -5407,21 +5820,37 @@ class SanaMultiscaleAttnProcessor2_0:
class LoRAAttnProcessor:
r"""
Processor for implementing attention with LoRA.
"""
def __init__(self):
pass
class LoRAAttnProcessor2_0:
r"""
Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
pass
class LoRAXFormersAttnProcessor:
r"""
Processor for implementing attention with LoRA using xFormers.
"""
def __init__(self):
pass
class LoRAAttnAddedKVProcessor:
r"""
Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
"""
def __init__(self):
pass
@@ -5437,6 +5866,165 @@ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
super().__init__()
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = hidden_states.dtype
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
query = F.relu(query)
key = F.relu(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
scores = torch.matmul(value, key)
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
hidden_states = hidden_states.to(original_dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if original_dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class PAGCFGSanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = hidden_states.dtype
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
query = F.relu(query)
key = F.relu(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
scores = torch.matmul(value, key)
hidden_states_org = torch.matmul(scores, query)
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
hidden_states_org = hidden_states_org.to(original_dtype)
hidden_states_org = attn.to_out[0](hidden_states_org)
hidden_states_org = attn.to_out[1](hidden_states_org)
# perturbed path (identity attention)
hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if original_dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class PAGIdentitySanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = hidden_states.dtype
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
query = F.relu(query)
key = F.relu(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
scores = torch.matmul(value, key)
hidden_states_org = torch.matmul(scores, query)
if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
hidden_states_org = hidden_states_org.float()
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
hidden_states_org = hidden_states_org.to(original_dtype)
hidden_states_org = attn.to_out[0](hidden_states_org)
hidden_states_org = attn.to_out[1](hidden_states_org)
# perturbed path (identity attention)
hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if original_dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -5451,6 +6039,7 @@ CROSS_ATTENTION_PROCESSORS = (
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
)
AttentionProcessor = Union[
@@ -5477,21 +6066,28 @@ AttentionProcessor = Union[
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,
MochiAttnProcessor2_0,
StableAudioAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FusedHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
MochiAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
SlicedAttnProcessor,
SlicedAttnAddedKVProcessor,
SanaLinearAttnProcessor2_0,
PAGCFGSanaLinearAttnProcessor2_0,
PAGIdentitySanaLinearAttnProcessor2_0,
SanaMultiscaleLinearAttention,
SanaMultiscaleAttnProcessor2_0,
SanaMultiscaleAttentionProjection,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
LoRAAttnProcessor,
@@ -3,6 +3,8 @@ from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
@@ -26,39 +26,10 @@ from ..activations import get_activation
from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import DecoderOutput, EncoderOutput
class GLUMBConv(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
hidden_channels = 4 * in_channels
self.nonlinearity = nn.SiLU()
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
return hidden_states + residual
class ResBlock(nn.Module):
def __init__(
self,
@@ -115,6 +86,7 @@ class EfficientViTBlock(nn.Module):
self.conv_out = GLUMBConv(
in_channels=in_channels,
out_channels=in_channels,
norm_type="rms_norm",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -15,7 +15,7 @@ if is_torch_available():
SparseControlNetModel,
SparseControlNetOutput,
)
from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
from .controlnet_union import ControlNetUnionModel
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
from .multicontrolnet import MultiControlNetModel
@@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...image_processor import PipelineImageInput
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ..attention_processor import (
@@ -40,76 +38,6 @@ from ..unets.unet_2d_condition import UNet2DConditionModel
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
@dataclass
class ControlNetUnionInput:
"""
The image input of [`ControlNetUnionModel`]:
- 0: openpose
- 1: depth
- 2: hed/pidi/scribble/ted
- 3: canny/lineart/anime_lineart/mlsd
- 4: normal
- 5: segment
"""
openpose: Optional[PipelineImageInput] = None
depth: Optional[PipelineImageInput] = None
hed: Optional[PipelineImageInput] = None
canny: Optional[PipelineImageInput] = None
normal: Optional[PipelineImageInput] = None
segment: Optional[PipelineImageInput] = None
def __len__(self) -> int:
return len(vars(self))
def __iter__(self):
return iter(vars(self))
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
setattr(self, key, value)
@dataclass
class ControlNetUnionInputProMax:
"""
The image input of [`ControlNetUnionModel`]:
- 0: openpose
- 1: depth
- 2: hed/pidi/scribble/ted
- 3: canny/lineart/anime_lineart/mlsd
- 4: normal
- 5: segment
- 6: tile
- 7: repaint
"""
openpose: Optional[PipelineImageInput] = None
depth: Optional[PipelineImageInput] = None
hed: Optional[PipelineImageInput] = None
canny: Optional[PipelineImageInput] = None
normal: Optional[PipelineImageInput] = None
segment: Optional[PipelineImageInput] = None
tile: Optional[PipelineImageInput] = None
repaint: Optional[PipelineImageInput] = None
def __len__(self) -> int:
return len(vars(self))
def __iter__(self):
return iter(vars(self))
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
setattr(self, key, value)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -680,8 +608,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
@@ -701,11 +630,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
controlnet_cond (`List[torch.Tensor]`):
The conditional input tensors.
control_type (`torch.Tensor`):
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
type is used.
control_type_idx (`List[int]`):
The indices of `control_type`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
@@ -733,20 +664,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(controlnet_cond) != self.config.num_control_type:
if isinstance(controlnet_cond, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
)
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
@@ -830,12 +747,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
inputs = []
condition_list = []
for idx, image_type in enumerate(controlnet_cond):
if controlnet_cond[image_type] is None:
continue
condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
for cond, control_idx in zip(controlnet_cond, control_type_idx):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[idx]
feat_seq = feat_seq + self.task_embedding[control_idx]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
+477 -18
View File
@@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed(
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
device: Optional[torch.device] = None,
output_type: str = "np",
) -> torch.Tensor:
r"""
Creates 3D sinusoidal positional embeddings.
Args:
embed_dim (`int`):
The embedding dimension of inputs. It must be divisible by 16.
spatial_size (`int` or `Tuple[int, int]`):
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
The temporal dimension of postional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
Scale factor for temporal grid interpolation.
Returns:
`torch.Tensor`:
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
embed_dim]`.
"""
if output_type == "np":
return _get_3d_sincos_pos_embed_np(
embed_dim=embed_dim,
spatial_size=spatial_size,
temporal_size=temporal_size,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
)
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
embed_dim_spatial = 3 * embed_dim // 4
embed_dim_temporal = embed_dim // 4
# 1. Spatial
grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
# 2. Temporal
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
# 3. Concat
pos_embed_spatial = pos_embed_spatial[None, :, :]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
pos_embed_temporal = pos_embed_temporal[:, None, :]
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
spatial_size[0] * spatial_size[1], dim=1
) # [T, H*W, D // 4]
pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
return pos_embed
def _get_3d_sincos_pos_embed_np(
embed_dim: int,
spatial_size: Union[int, Tuple[int, int]],
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
Creates 3D sinusoidal positional embeddings.
@@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed(
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
embed_dim]`.
"""
deprecation_message = (
"`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
@@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed(
def get_2d_sincos_pos_embed(
embed_dim,
grid_size,
cls_token=False,
extra_tokens=0,
interpolation_scale=1.0,
base_size=16,
device: Optional[torch.device] = None,
output_type: str = "np",
):
"""
Creates 2D sinusoidal positional embeddings.
Args:
embed_dim (`int`):
The embedding dimension.
grid_size (`int`):
The size of the grid height and width.
cls_token (`bool`, defaults to `False`):
Whether or not to add a classification token.
extra_tokens (`int`, defaults to `0`):
The number of extra tokens to add.
interpolation_scale (`float`, defaults to `1.0`):
The scale of the interpolation.
Returns:
pos_embed (`torch.Tensor`):
Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
embed_dim]` if using cls_token
"""
if output_type == "np":
deprecation_message = (
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_2d_sincos_pos_embed_np(
embed_dim=embed_dim,
grid_size=grid_size,
cls_token=cls_token,
extra_tokens=extra_tokens,
interpolation_scale=interpolation_scale,
base_size=base_size,
)
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = (
torch.arange(grid_size[0], device=device, dtype=torch.float32)
/ (grid_size[0] / base_size)
/ interpolation_scale
)
grid_w = (
torch.arange(grid_size[1], device=device, dtype=torch.float32)
/ (grid_size[1] / base_size)
/ interpolation_scale
)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
if cls_token and extra_tokens > 0:
pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
r"""
This function generates 2D sinusoidal positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension.
grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
Returns:
`torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
"""
if output_type == "np":
deprecation_message = (
"`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_2d_sincos_pos_embed_from_grid_np(
embed_dim=embed_dim,
grid=grid,
)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)
emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
"""
if output_type == "np":
deprecation_message = (
"`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.outer(pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
def get_2d_sincos_pos_embed_np(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
@@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed(
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
r"""
This function generates 2D sinusoidal positional embeddings from a grid.
@@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
"""
This function generates 1D positional embeddings from a grid.
@@ -288,10 +503,14 @@ class PatchEmbed(nn.Module):
self.pos_embed = None
elif pos_embed_type == "sincos":
pos_embed = get_2d_sincos_pos_embed(
embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
embed_dim,
grid_size,
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
output_type="pt",
)
persistent = True if pos_embed_max_size else False
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
else:
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
@@ -323,7 +542,6 @@ class PatchEmbed(nn.Module):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
@@ -341,8 +559,10 @@ class PatchEmbed(nn.Module):
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
device=latent.device,
output_type="pt",
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
pos_embed = pos_embed.float().unsqueeze(0)
else:
pos_embed = self.pos_embed
@@ -453,7 +673,9 @@ class CogVideoXPatchEmbed(nn.Module):
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
def _get_positional_embeddings(
self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -465,9 +687,11 @@ class CogVideoXPatchEmbed(nn.Module):
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=device,
output_type="pt",
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -521,8 +745,10 @@ class CogVideoXPatchEmbed(nn.Module):
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
pos_embedding = self._get_positional_embeddings(
height, width, pre_time_compression_frames, device=embeds.device
)
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
else:
pos_embedding = self.pos_embedding
@@ -552,9 +778,11 @@ class CogView3PlusPatchEmbed(nn.Module):
# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
pos_embed = get_2d_sincos_pos_embed(
hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
)
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
@@ -729,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
def get_2d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
):
"""
RoPE for image tokens with 2d structure.
Args:
embed_dim: (`int`):
The embedding dimension size
crops_coords (`Tuple[int]`)
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
device: (`torch.device`, **optional**):
The device used to create tensors.
Returns:
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
"""
if output_type == "np":
deprecation_message = (
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return _get_2d_rotary_pos_embed_np(
embed_dim=embed_dim,
crops_coords=crops_coords,
grid_size=grid_size,
use_real=use_real,
)
start, stop = crops_coords
# scale end by (steps1)/steps matches np.linspace(..., endpoint=False)
grid_h = torch.linspace(
start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
)
grid_w = torch.linspace(
start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
@@ -1257,7 +1535,7 @@ class ImageProjection(nn.Module):
batch_size = image_embeds.shape[0]
# image
image_embeds = self.image_embeds(image_embeds)
image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
@@ -2118,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
return out
class IPAdapterTimeImageProjectionBlock(nn.Module):
"""Block for IPAdapterTimeImageProjection.
Args:
hidden_dim (`int`, defaults to 1280):
The number of hidden channels.
dim_head (`int`, defaults to 64):
The number of head channels.
heads (`int`, defaults to 20):
Parallel attention heads.
ffn_ratio (`int`, defaults to 4):
The expansion ratio of feedforward network hidden layer channels.
"""
def __init__(
self,
hidden_dim: int = 1280,
dim_head: int = 64,
heads: int = 20,
ffn_ratio: int = 4,
) -> None:
super().__init__()
from .attention import FeedForward
self.ln0 = nn.LayerNorm(hidden_dim)
self.ln1 = nn.LayerNorm(hidden_dim)
self.attn = Attention(
query_dim=hidden_dim,
cross_attention_dim=hidden_dim,
dim_head=dim_head,
heads=heads,
bias=False,
out_bias=False,
)
self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
# AdaLayerNorm
self.adaln_silu = nn.SiLU()
self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
self.adaln_norm = nn.LayerNorm(hidden_dim)
# Set attention scale and fuse KV
self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
self.attn.fuse_projections()
self.attn.to_k = None
self.attn.to_v = None
def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x (`torch.Tensor`):
Image features.
latents (`torch.Tensor`):
Latent features.
timestep_emb (`torch.Tensor`):
Timestep embedding.
Returns:
`torch.Tensor`: Output latent features.
"""
# Shift and scale for AdaLayerNorm
emb = self.adaln_proj(self.adaln_silu(timestep_emb))
shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
# Fused Attention
residual = latents
x = self.ln0(x)
latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
batch_size = latents.shape[0]
query = self.attn.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.attn.heads
query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
latents = weight @ value
latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
latents = self.attn.to_out[0](latents)
latents = self.attn.to_out[1](latents)
latents = latents + residual
## FeedForward
residual = latents
latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
return self.ff(latents) + residual
# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
class IPAdapterTimeImageProjection(nn.Module):
"""Resampler of SD3 IP-Adapter with timestep embedding.
Args:
embed_dim (`int`, defaults to 1152):
The feature dimension.
output_dim (`int`, defaults to 2432):
The number of output channels.
hidden_dim (`int`, defaults to 1280):
The number of hidden channels.
depth (`int`, defaults to 4):
The number of blocks.
dim_head (`int`, defaults to 64):
The number of head channels.
heads (`int`, defaults to 20):
Parallel attention heads.
num_queries (`int`, defaults to 64):
The number of queries.
ffn_ratio (`int`, defaults to 4):
The expansion ratio of feedforward network hidden layer channels.
timestep_in_dim (`int`, defaults to 320):
The number of input channels for timestep embedding.
timestep_flip_sin_to_cos (`bool`, defaults to True):
Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
timestep_freq_shift (`int`, defaults to 0):
Controls the timestep delta between frequencies between dimensions.
"""
def __init__(
self,
embed_dim: int = 1152,
output_dim: int = 2432,
hidden_dim: int = 1280,
depth: int = 4,
dim_head: int = 64,
heads: int = 20,
num_queries: int = 64,
ffn_ratio: int = 4,
timestep_in_dim: int = 320,
timestep_flip_sin_to_cos: bool = True,
timestep_freq_shift: int = 0,
) -> None:
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
self.proj_in = nn.Linear(embed_dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList(
[IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
)
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass.
Args:
x (`torch.Tensor`):
Image features.
timestep (`torch.Tensor`):
Timestep in denoising process.
Returns:
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
"""
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
timestep_emb = self.time_embedding(timestep_emb)
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
x = x + timestep_emb[:, None]
for block in self.layers:
latents = block(x, latents, timestep_emb)
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents, timestep_emb
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
+85 -5
View File
@@ -17,6 +17,7 @@
import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
@@ -25,8 +26,8 @@ import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
@@ -34,6 +35,8 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_gguf_available,
is_torch_available,
is_torch_version,
logging,
)
@@ -140,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
@@ -182,7 +187,6 @@ def load_model_dict_into_meta(
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
empty_state_dict = model.state_dict()
@@ -213,14 +217,15 @@ def load_model_dict_into_meta(
set_module_kwargs["dtype"] = dtype
# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
if (
is_quant_method_bnb
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
elif not is_quant_method_bnb:
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
@@ -398,3 +403,78 @@ def _fetch_index_file_legacy(
index_file = None
return index_file
def _gguf_parse_value(_value, data_type):
if not isinstance(data_type, list):
data_type = [data_type]
if len(data_type) == 1:
data_type = data_type[0]
array_data_type = None
else:
if data_type[0] != 9:
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
data_type, array_data_type = data_type
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
_value = int(_value[0])
elif data_type in [6, 12]:
_value = float(_value[0])
elif data_type in [7]:
_value = bool(_value[0])
elif data_type in [8]:
_value = array("B", list(_value)).tobytes().decode()
elif data_type in [9]:
_value = _gguf_parse_value(_value, array_data_type)
return _value
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
"""
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
attributes.
Args:
gguf_checkpoint_path (`str`):
The path the to GGUF file to load
return_tensors (`bool`, defaults to `True`):
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
metadata in memory.
"""
if is_gguf_available() and is_torch_available():
import gguf
from gguf import GGUFReader
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
else:
logger.error(
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
)
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
reader = GGUFReader(gguf_checkpoint_path)
parsed_parameters = {}
for tensor in reader.tensors:
name = tensor.name
quant_type = tensor.tensor_type
# if the tensor is a torch supported dtype do not use GGUFParameter
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
raise ValueError(
(
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
"\n\nCurrently the following quantization types are supported: \n\n"
f"{_supported_quants_str}"
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
)
)
weights = torch.from_numpy(tensor.data.copy())
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
+41 -24
View File
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try:
return next(parameter.parameters()).dtype
except StopIteration:
try:
return next(parameter.buffers()).dtype
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
last_dtype = None
for param in parameter.parameters():
last_dtype = param.dtype
if param.is_floating_point():
return param.dtype
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
for buffer in parameter.buffers():
last_dtype = buffer.dtype
if buffer.is_floating_point():
return buffer.dtype
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
class ModelMixin(torch.nn.Module, PushToHubMixin):
@@ -700,10 +718,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer = None
if hf_quantizer is not None:
if device_map is not None:
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
if is_bnb_quantization_method and device_map is not None:
raise NotImplementedError(
"Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
)
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
@@ -800,7 +820,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None:
if hf_quantizer is not None and is_bnb_quantization_method:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
@@ -858,13 +878,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if device_map is None and not is_sharded:
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
# It would error out during the `validate_environment()` call above in the absence of cuda.
is_quant_method_bnb = (
getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
)
if hf_quantizer is None:
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
@@ -1039,14 +1056,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dtype_present_in_args = True
break
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_quantized", False):
if dtype_present_in_args:
raise ValueError(
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
" desired `dtype` by passing the correct `torch_dtype` argument."
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
)
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+30 -27
View File
@@ -234,33 +234,6 @@ class LuminaRMSNormZero(nn.Module):
return x, gate_msa, scale_mlp, gate_mlp
class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
return hidden_states, gate_msa, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -549,6 +522,36 @@ class RMSNorm(nn.Module):
return hidden_states
# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
class MochiRMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype)
return hidden_states
class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
@@ -11,12 +11,15 @@ if is_torch_available():
from .lumina_nextdit2d import LuminaNextDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .sana_transformer import SanaTransformer2DModel
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -156,9 +156,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
# define temporal positional embedding
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
inner_dim, torch.arange(0, video_length).unsqueeze(1)
inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
) # 1152 hidden size
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
self.gradient_checkpointing = False
@@ -0,0 +1,487 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor2_0,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class GLUMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 4,
norm_type: Optional[str] = None,
residual_connection: bool = True,
) -> None:
super().__init__()
hidden_channels = int(expand_ratio * in_channels)
self.norm_type = norm_type
self.residual_connection = residual_connection
self.nonlinearity = nn.SiLU()
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = None
if norm_type == "rms_norm":
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.residual_connection:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
"""
def __init__(
self,
dim: int = 2240,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
dropout: float = 0.0,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
attention_bias: bool = True,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
attention_out_bias: bool = True,
mlp_ratio: float = 2.5,
) -> None:
super().__init__()
# 1. Self Attention
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
processor=SanaLinearAttnProcessor2_0(),
)
# 2. Cross Attention
if cross_attention_dim is not None:
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_cross_attention_heads,
dim_head=cross_attention_head_dim,
dropout=dropout,
bias=True,
out_bias=attention_out_bias,
processor=AttnProcessor2_0(),
)
# 3. Feed-forward
self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
height: int = None,
width: int = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
# 1. Modulation
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
# 2. Self Attention
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
attn_output = self.attn1(norm_hidden_states)
hidden_states = hidden_states + gate_msa * attn_output
# 3. Cross Attention
if self.attn2 is not None:
attn_output = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
ff_output = self.ff(norm_hidden_states)
ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
hidden_states = hidden_states + gate_mlp * ff_output
return hidden_states
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
Args:
in_channels (`int`, defaults to `32`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `32`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `70`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `32`):
The number of channels in each head.
num_layers (`int`, defaults to `20`):
The number of layers of Transformer blocks to use.
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
The number of heads to use for cross-attention.
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
The number of channels in each head for cross-attention.
cross_attention_dim (`int`, *optional*, defaults to `2240`):
The number of channels in the cross-attention output.
caption_channels (`int`, defaults to `2304`):
The number of channels in the caption embeddings.
mlp_ratio (`float`, defaults to `2.5`):
The expansion ratio to use in the GLUMBConv layer.
dropout (`float`, defaults to `0.0`):
The dropout probability.
attention_bias (`bool`, defaults to `False`):
Whether to use bias in the attention layer.
sample_size (`int`, defaults to `32`):
The base size of the input latent.
patch_size (`int`, defaults to `1`):
The size of the patches to use in the patch embedding layer.
norm_elementwise_affine (`bool`, defaults to `False`):
Whether to use elementwise affinity in the normalization layer.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value for the normalization layer.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
@register_to_config
def __init__(
self,
in_channels: int = 32,
out_channels: Optional[int] = 32,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
num_layers: int = 20,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim
# 1. Patch Embedding
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=None,
pos_embed_type=None,
)
# 2. Additional condition embeddings
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
SanaTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
# 4. Output blocks
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size, num_channels, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height, post_patch_width = height // p, width // p
hidden_states = self.patch_embed(hidden_states)
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
# 2. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for block in self.transformer_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
**ckpt_kwargs,
)
else:
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
)
# 3. Normalization
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# 4. Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
hidden_states = hidden_states.reshape(
batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
)
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -21,7 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
@@ -177,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attn_output, context_attn_output = self.attn(
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
@@ -195,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
@@ -212,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
):
"""
The Transformer model introduced in Flux.
@@ -482,6 +491,11 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -524,7 +538,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
@@ -0,0 +1,787 @@
# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideoAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn.add_q_proj is None and encoder_hidden_states is not None:
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if attn.add_q_proj is None and encoder_hidden_states is not None:
query = torch.cat(
[
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
query[:, :, -encoder_hidden_states.shape[1] :],
],
dim=2,
)
key = torch.cat(
[
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
key[:, :, -encoder_hidden_states.shape[1] :],
],
dim=2,
)
else:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class HunyuanVideoPatchEmbed(nn.Module):
def __init__(
self,
patch_size: Union[int, Tuple[int, int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
return hidden_states
class HunyuanVideoAdaNorm(nn.Module):
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
self.linear = nn.Linear(in_features, out_features)
self.nonlinearity = nn.SiLU()
def forward(
self, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
return gate_msa, gate_mlp
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
)
gate_msa, gate_mlp = self.norm_out(temb)
hidden_states = hidden_states + attn_output * gate_msa
ff_output = self.ff(self.norm2(hidden_states))
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
class HunyuanVideoIndividualTokenRefiner(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
self.refiner_blocks = nn.ModuleList(
[
HunyuanVideoIndividualTokenRefinerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
for _ in range(num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> None:
self_attn_mask = None
if attention_mask is not None:
batch_size = attention_mask.shape[0]
seq_len = attention_mask.shape[1]
attention_mask = attention_mask.to(hidden_states.device).bool()
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True
for block in self.refiner_blocks:
hidden_states = block(hidden_states, temb, self_attn_mask)
return hidden_states
class HunyuanVideoTokenRefiner(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=hidden_size, pooled_projection_dim=in_channels
)
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
self.token_refiner = HunyuanVideoIndividualTokenRefiner(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=num_layers,
mlp_width_ratio=mlp_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_projections = hidden_states.mean(dim=1)
else:
original_dtype = hidden_states.dtype
mask_float = attention_mask.float().unsqueeze(-1)
pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
pooled_projections = pooled_projections.to(original_dtype)
temb = self.time_text_embed(timestep, pooled_projections)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
return hidden_states
class HunyuanVideoRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.rope_dim = rope_dim
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
axes_grids = []
for i in range(3):
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
# original implementation creates it on CPU and then moves it to device. This results in numerical
# differences in layerwise debugging outputs, but visually it is the same.
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
axes_grids.append(grid)
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
grid = torch.stack(grid, dim=0) # [3, W, H, T]
freqs = []
for i in range(3):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
return freqs_cos, freqs_sin
class HunyuanVideoSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=freqs_cis,
)
# 3. Modulation and residual connection
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
Args:
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `20`):
The number of layers of dual-stream blocks to use.
num_single_layers (`int`, defaults to `40`):
The number of layers of single-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`int`, defaults to `2`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
qk_norm (`str`, defaults to `rms_norm`):
The normalization to use for the query and key projections in the attention layers.
guidance_embeds (`bool`, defaults to `True`):
Whether to use guidance embeddings in the model.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
pooled_projection_dim (`int`, defaults to `768`):
The dimension of the pooled projection of the text embeddings.
rope_theta (`float`, defaults to `256.0`):
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: int = 2,
patch_size_t: int = 1,
qk_norm: str = "rms_norm",
guidance_embeds: bool = True,
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
temb = self.time_text_embed(timestep, guidance, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
# 3. Attention mask preparation
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
# 5. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
@@ -0,0 +1,469 @@
# Copyright 2024 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LTXAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTXRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
base_num_frames: int = 20,
base_height: int = 2048,
base_width: int = 2048,
patch_size: int = 1,
patch_size_t: int = 1,
theta: float = 10000.0,
) -> None:
super().__init__()
self.dim = dim
self.base_num_frames = base_num_frames
self.base_height = base_height
self.base_width = base_width
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.theta = theta
def forward(
self,
hidden_states: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
# Always compute rope in fp32
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if rope_interpolation_scale is not None:
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
grid = grid.flatten(2, 4).transpose(1, 2)
start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
math.log(start, self.theta),
math.log(end, self.theta),
self.dim // 6,
device=hidden_states.device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
freqs = freqs.transpose(-1, -2).flatten(2)
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % 6 != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
return cos_freqs, sin_freqs
@maybe_allow_in_graph
class LTXTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
attention_out_bias: bool = True,
eps: float = 1e-6,
elementwise_affine: bool = False,
):
super().__init__()
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = hidden_states.size(0)
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
attn_hidden_states = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
Args:
in_channels (`int`, defaults to `128`):
The number of channels in the input.
out_channels (`int`, defaults to `128`):
The number of channels in the output.
patch_size (`int`, defaults to `1`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `32`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
cross_attention_dim (`int`, defaults to `2048 `):
The number of channels for cross attention heads.
num_layers (`int`, defaults to `28`):
The number of layers of Transformer blocks to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
The normalization layer to use.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
patch_size_t: int = 1,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
cross_attention_dim: int = 2048,
num_layers: int = 28,
activation_fn: str = "gelu-approximate",
qk_norm: str = "rms_norm_across_heads",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 4096,
attention_bias: bool = True,
attention_out_bias: bool = True,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim
self.proj_in = nn.Linear(in_channels, inner_dim)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.rope = LTXRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
base_width=2048,
patch_size=patch_size,
patch_size_t=patch_size_t,
theta=10000.0,
)
self.transformer_blocks = nn.ModuleList(
[
LTXTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
qk_norm=qk_norm,
activation_fn=activation_fn,
attention_bias=attention_bias,
attention_out_bias=attention_out_bias,
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
hidden_states = self.proj_in(hidden_states)
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def apply_rotary_emb(x, freqs):
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
@@ -20,19 +20,100 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, MochiAttnProcessor2_0
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
from ..normalization import AdaLayerNormContinuous, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MochiModulatedRMSNorm(nn.Module):
def __init__(self, eps: float):
super().__init__()
self.eps = eps
self.norm = RMSNorm(0, eps, False)
def forward(self, hidden_states, scale=None):
hidden_states_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
hidden_states = self.norm(hidden_states)
if scale is not None:
hidden_states = hidden_states * scale
hidden_states = hidden_states.to(hidden_states_dtype)
return hidden_states
class MochiLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
eps=1e-5,
bias=True,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
self.norm = MochiModulatedRMSNorm(eps=eps)
def forward(
self,
x: torch.Tensor,
conditioning_embedding: torch.Tensor,
) -> torch.Tensor:
input_dtype = x.dtype
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
return x.to(input_dtype)
class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(0, eps, False)
def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states_dtype = hidden_states.dtype
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
hidden_states = hidden_states.to(hidden_states_dtype)
return hidden_states, gate_msa, scale_mlp, gate_mlp
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
r"""
@@ -77,38 +158,32 @@ class MochiTransformerBlock(nn.Module):
if not context_pre_only:
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
else:
self.norm1_context = LuminaLayerNormContinuous(
self.norm1_context = MochiLayerNormContinuous(
embedding_dim=pooled_projection_dim,
conditioning_embedding_dim=dim,
eps=eps,
elementwise_affine=False,
norm_type="rms_norm",
out_dim=None,
)
self.attn1 = Attention(
self.attn1 = MochiAttention(
query_dim=dim,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=False,
qk_norm=qk_norm,
added_kv_proj_dim=pooled_projection_dim,
added_proj_bias=False,
out_dim=dim,
out_context_dim=pooled_projection_dim,
context_pre_only=context_pre_only,
processor=MochiAttnProcessor2_0(),
eps=eps,
elementwise_affine=True,
eps=1e-5,
)
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm2 = MochiModulatedRMSNorm(eps=eps)
self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm3 = MochiModulatedRMSNorm(eps)
self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
self.ff_context = None
@@ -120,14 +195,15 @@ class MochiTransformerBlock(nn.Module):
bias=False,
)
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm4 = MochiModulatedRMSNorm(eps=eps)
self.norm4_context = MochiModulatedRMSNorm(eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_attention_mask: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
@@ -143,22 +219,25 @@ class MochiTransformerBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
if not self.context_pre_only:
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
context_attn_hidden_states
) * torch.tanh(enc_gate_msa).unsqueeze(1)
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
)
norm_encoder_hidden_states = self.norm3_context(
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
enc_gate_mlp
).unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
)
return hidden_states, encoder_hidden_states
@@ -203,7 +282,10 @@ class MochiRoPE(nn.Module):
return positions
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
with torch.autocast(freqs.device.type, torch.float32):
# Always run ROPE freqs computation in FP32
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
@@ -223,7 +305,7 @@ class MochiRoPE(nn.Module):
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
@@ -253,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["MochiTransformerBlock"]
@register_to_config
def __init__(
@@ -309,7 +392,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
self.norm_out = AdaLayerNormContinuous(
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
inner_dim,
inner_dim,
elementwise_affine=False,
eps=1e-6,
norm_type="layer_norm",
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
@@ -350,7 +437,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_width = width // p
temb, encoder_hidden_states = self.time_embed(
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
timestep,
encoder_hidden_states,
encoder_attention_mask,
hidden_dtype=hidden_states.dtype,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
@@ -381,6 +471,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
@@ -389,9 +480,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
@@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
@@ -103,7 +103,9 @@ class SD3SingleTransformerBlock(nn.Module):
return hidden_states
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class SD3Transformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
"""
The Transformer model introduced in Stable Diffusion 3.
@@ -349,8 +351,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
Embeddings projected from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states (`list` of `torch.Tensor`):
@@ -390,6 +392,12 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
is_skip = True if skip_layers is not None and index_block in skip_layers else False
@@ -411,11 +419,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
hidden_states,
encoder_hidden_states,
temb,
joint_attention_kwargs,
**ckpt_kwargs,
)
elif not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
+1 -1
View File
@@ -217,7 +217,7 @@ class MidResTemporalBlock1D(nn.Module):
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
self.downsample = self.downsample(hidden_states)
hidden_states = self.downsample(hidden_states)
return hidden_states
+8 -1
View File
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
conditioning with `class_embed_type` equal to `None`.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
time_embedding_dim: Optional[int] = None,
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.Tensor,
+75 -8
View File
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(hidden_states)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
+4 -2
View File
@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet in self.resnets:
# pop res hidden states
@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -382,6 +382,20 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
returned, otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
@@ -457,15 +471,23 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
else:
@@ -473,6 +495,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
+10
View File
@@ -128,6 +128,7 @@ else:
]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
"FluxControlImg2ImgPipeline",
"FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -185,6 +186,7 @@ else:
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"PixArtSigmaPAGPipeline",
"SanaPAGPipeline",
]
)
_import_structure["controlnet_xs"].extend(
@@ -213,6 +215,7 @@ else:
"IFSuperResolutionPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -250,6 +253,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["marigold"].extend(
[
@@ -262,6 +266,7 @@ else:
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["sana"] = ["SanaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
@@ -535,6 +540,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
@@ -546,6 +552,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hunyuan_video import HunyuanVideoPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
@@ -585,6 +592,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaText2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
@@ -597,6 +605,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline,
KolorsPAGPipeline,
PixArtSigmaPAGPipeline,
SanaPAGPipeline,
StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
@@ -613,6 +622,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .sana import SanaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -59,6 +59,7 @@ EXAMPLE_DOC_STRING = """
>>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
>>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
>>> pipe.enable_vae_tiling()
>>> prompt = (
... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
@@ -636,6 +637,35 @@ class AllegroPipeline(DiffusionPipeline):
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
@property
def guidance_scale(self):
return self._guidance_scale
+49 -6
View File
@@ -18,6 +18,7 @@ from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
@@ -28,12 +29,18 @@ from .controlnet import (
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetUnionImg2ImgPipeline,
StableDiffusionXLControlNetUnionInpaintPipeline,
StableDiffusionXLControlNetUnionPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
@@ -108,6 +115,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
("cascade", StableCascadeCombinedPipeline),
("lcm", LatentConsistencyModelPipeline),
@@ -120,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
@@ -139,11 +148,13 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
]
)
@@ -158,9 +169,11 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
]
)
@@ -394,13 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
if "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
else:
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -684,16 +704,28 @@ class AutoPipelineForImage2Image(ConfigMixin):
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "Img2Img" in orig_class_name:
to_replace = "Img2ImgPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
else:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -981,15 +1013,26 @@ class AutoPipelineForInpainting(ConfigMixin):
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "Inpaint" in orig_class_name:
to_replace = "InpaintPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
else:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -38,7 +38,7 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import CogView3PlusPipeline
>>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16)
>>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars"
@@ -31,6 +31,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -42,6 +43,13 @@ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -1323,6 +1331,8 @@ class StableDiffusionControlNetPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -40,7 +40,6 @@ from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -82,7 +81,6 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
import numpy as np
@@ -114,11 +112,8 @@ EXAMPLE_DOC_STRING = """
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
union_input = ControlNetUnionInputProMax(
repaint=controlnet_img,
)
# generate image
image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
image.save("inpaint.png")
```
"""
@@ -210,11 +205,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
"mask",
"masked_image_latents",
]
@@ -1130,7 +1122,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
@@ -1158,6 +1150,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1345,20 +1338,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(control_image_list) != controlnet.config.num_control_type:
if isinstance(control_image_list, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(control_image_list, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1375,36 +1354,44 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list):
control_image = [control_image]
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = []
for image_type in control_image_list:
if control_image_list[image_type]:
self.check_inputs(
prompt,
prompt_2,
control_image_list[image_type],
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
control_type.append(1)
else:
control_type.append(0)
control_type = [0 for _ in range(num_control_type)]
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
control_type = torch.Tensor(control_type)
@@ -1499,23 +1486,21 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
init_image = init_image.to(dtype=torch.float32)
# 5.2 Prepare control images
for image_type in control_image_list:
if control_image_list[image_type]:
control_image = self.prepare_control_image(
image=control_image_list[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[image_type] = control_image
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
@@ -1589,6 +1574,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]
# 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
@@ -1693,8 +1681,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image_list,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
@@ -43,7 +43,6 @@ from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -70,7 +69,6 @@ EXAMPLE_DOC_STRING = """
>>> # !pip install controlnet_aux
>>> from controlnet_aux import LineartAnimeDetector
>>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
>>> from diffusers.models.controlnets import ControlNetUnionInput
>>> from diffusers.utils import load_image
>>> import torch
@@ -89,17 +87,14 @@ EXAMPLE_DOC_STRING = """
... controlnet=controlnet,
... vae=vae,
... torch_dtype=torch.float16,
... variant="fp16",
... )
>>> pipe.enable_model_cpu_offload()
>>> # prepare image
>>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
>>> controlnet_img = processor(image, output_type="pil")
>>> # set ControlNetUnion input
>>> union_input = ControlNetUnionInput(
... canny=controlnet_img,
... )
>>> # generate image
>>> image = pipe(prompt, image=union_input).images[0]
>>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0]
```
"""
@@ -226,12 +221,8 @@ class StableDiffusionXLControlNetUnionPipeline(
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
"image",
]
def __init__(
@@ -791,26 +782,6 @@ class StableDiffusionXLControlNetUnionPipeline(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def check_input(
self,
image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
):
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(image) != controlnet.config.num_control_type:
if isinstance(image, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(image, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`."
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_image(
self,
@@ -970,7 +941,7 @@ class StableDiffusionXLControlNetUnionPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -997,6 +968,7 @@ class StableDiffusionXLControlNetUnionPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
@@ -1018,10 +990,7 @@ class StableDiffusionXLControlNetUnionPipeline(
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders.
image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
`List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
control_image (`PipelineImageInput`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
@@ -1168,38 +1137,45 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
self.check_input(image)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
if not isinstance(control_image, list):
control_image = [control_image]
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = [0 for _ in range(num_control_type)]
# 1. Check inputs. Raise error if not correct
control_type = []
for image_type in image:
if image[image_type]:
self.check_inputs(
prompt,
prompt_2,
image[image_type],
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type.append(1)
else:
control_type.append(0)
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type = torch.Tensor(control_type)
@@ -1258,20 +1234,19 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 4. Prepare image
for image_type in image:
if image[image_type]:
image[image_type] = self.prepare_image(
image=image[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image[image_type].shape[-2:]
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
@@ -1312,11 +1287,11 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 7.2 Prepare added time ids & embeddings
for image_type in image:
if isinstance(image[image_type], torch.Tensor):
original_size = original_size or image[image_type].shape[-2:]
original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
@@ -1424,8 +1399,9 @@ class StableDiffusionXLControlNetUnionPipeline(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
@@ -1471,14 +1447,8 @@ class StableDiffusionXLControlNetUnionPipeline(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
image = callback_outputs.pop("image", image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -43,7 +43,6 @@ from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -74,7 +73,6 @@ EXAMPLE_DOC_STRING = """
ControlNetUnionModel,
AutoencoderKL,
)
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
from PIL import Image
@@ -95,6 +93,7 @@ EXAMPLE_DOC_STRING = """
controlnet=controlnet,
vae=vae,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
# `enable_model_cpu_offload` is not recommended due to multiple generations
height = image.height
@@ -132,14 +131,12 @@ EXAMPLE_DOC_STRING = """
# set ControlNetUnion input
result_images = []
for sub_img, crops_coords in zip(images, crops_coords_list):
union_input = ControlNetUnionInputProMax(
tile=sub_img,
)
new_width, new_height = W, H
out = pipe(
prompt=[prompt] * 1,
image=sub_img,
control_image_list=union_input,
control_image=[sub_img],
control_mode=[6],
width=new_width,
height=new_height,
num_inference_steps=30,
@@ -247,11 +244,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
]
def __init__(
@@ -1065,7 +1059,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
@@ -1090,6 +1084,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
@@ -1119,10 +1114,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
`List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`)::
control_image (`PipelineImageInput`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
@@ -1291,53 +1283,47 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(control_image_list) != controlnet.config.num_control_type:
if isinstance(control_image_list, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(control_image_list, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
# 1. Check inputs. Raise error if not correct
control_type = []
for image_type in control_image_list:
if control_image_list[image_type]:
self.check_inputs(
prompt,
prompt_2,
control_image_list[image_type],
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type.append(1)
else:
control_type.append(0)
if not isinstance(control_image, list):
control_image = [control_image]
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")
num_control_type = controlnet.config.num_control_type
# 1. Check inputs
control_type = [0 for _ in range(num_control_type)]
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
control_type = torch.Tensor(control_type)
@@ -1397,21 +1383,19 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
for image_type in control_image_list:
if control_image_list[image_type]:
control_image = self.prepare_control_image(
image=control_image_list[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[image_type] = control_image
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1444,10 +1428,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
)
# 7.2 Prepare added time ids & embeddings
for image_type in control_image_list:
if isinstance(control_image_list[image_type], torch.Tensor):
original_size = original_size or control_image_list[image_type].shape[-2:]
original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]
if negative_original_size is None:
negative_original_size = original_size
@@ -1531,8 +1516,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image_list,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
@@ -1577,13 +1563,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -925,7 +925,11 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)
style = torch.tensor([0], device=device)
@@ -66,9 +66,13 @@ EXAMPLE_DOC_STRING = """
... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
>>> prompt = "A girl holding a sign that says InstantX"
>>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0]
>>> control_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
... )
>>> prompt = "A bird in space"
>>> image = pipe(
... prompt, control_image=control_image, height=1024, width=768, controlnet_conditioning_scale=0.7
... ).images[0]
>>> image.save("sd3.png")
```
"""
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb)
return hidden_states
+2
View File
@@ -26,6 +26,7 @@ else:
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
_import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
_import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_flux import FluxPipeline
from .pipeline_flux_control import FluxControlPipeline
from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+173 -5
View File
@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,6 +149,7 @@ class FluxPipeline(
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
FluxIPAdapterMixin,
):
r"""
The Flux pipeline for text-to-image generation.
@@ -169,8 +177,8 @@ class FluxPipeline(
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -182,6 +190,8 @@ class FluxPipeline(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -193,6 +203,8 @@ class FluxPipeline(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -377,14 +389,60 @@ class FluxPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
image_embeds.append(single_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -419,10 +477,33 @@ class FluxPipeline(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -551,6 +632,9 @@ class FluxPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
@@ -561,6 +645,12 @@ class FluxPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -610,6 +700,17 @@ class FluxPipeline(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image:
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -647,8 +748,12 @@ class FluxPipeline(
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -670,6 +775,7 @@ class FluxPipeline(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -684,6 +790,21 @@ class FluxPipeline(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -725,12 +846,43 @@ class FluxPipeline(
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -746,6 +898,22 @@ class FluxPipeline(
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -403,7 +403,6 @@ class FluxControlPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
def check_inputs(
self,
prompt,
File diff suppressed because it is too large Load Diff
@@ -1095,7 +1095,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# predict the noise residual
if self.controlnet.config.guidance_embeds:
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds
if use_guidance:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuan_video import HunyuanVideoPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,687 @@
# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanVideoPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import export_to_video
>>> model_id = "tencent/HunyuanVideo"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
>>> pipe.vae.enable_tiling()
>>> pipe.to("cuda")
>>> output = pipe(
... prompt="A cat walks on the grass, realistic",
... height=320,
... width=512,
... num_frames=61,
... num_inference_steps=30,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=15)
```
"""
DEFAULT_PROMPT_TEMPLATE = {
"template": (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
),
"crop_start": 95,
}
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`LlamaModel`]):
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
tokenizer (`LlamaTokenizer`):
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
transformer ([`HunyuanVideoTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder_2 ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer_2 (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
text_encoder: LlamaModel,
tokenizer: LlamaTokenizerFast,
transformer: HunyuanVideoTransformer3DModel,
vae: AutoencoderKLHunyuanVideo,
scheduler: FlowMatchEulerDiscreteScheduler,
text_encoder_2: CLIPTextModel,
tokenizer_2: CLIPTokenizer,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
)
self.vae_scale_factor_temporal = (
self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scale_factor_spatial = (
self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_llama_prompt_embeds(
self,
prompt: Union[str, List[str]],
prompt_template: Dict[str, Any],
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
num_hidden_layers_to_skip: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
prompt = [prompt_template["template"].format(p) for p in prompt]
crop_start = prompt_template.get("crop_start", None)
if crop_start is None:
prompt_template_input = self.tokenizer(
prompt_template["template"],
padding="max_length",
return_tensors="pt",
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=False,
)
crop_start = prompt_template_input["input_ids"].shape[-1]
# Remove <|eot_id|> token and placeholder {}
crop_start -= 2
max_sequence_length += crop_start
text_inputs = self.tokenizer(
prompt,
max_length=max_sequence_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
)
text_input_ids = text_inputs.input_ids.to(device=device)
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
prompt_embeds = prompt_embeds.to(dtype=dtype)
if crop_start is not None and crop_start > 0:
prompt_embeds = prompt_embeds[:, crop_start:]
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
return prompt_embeds, prompt_attention_mask
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 77,
) -> torch.Tensor:
device = device or self._execution_device
dtype = dtype or self.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]] = None,
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
prompt,
prompt_template,
num_videos_per_prompt,
device=device,
dtype=dtype,
max_sequence_length=max_sequence_length,
)
if pooled_prompt_embeds is None:
if prompt_2 is None and pooled_prompt_embeds is None:
prompt_2 = prompt
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt,
num_videos_per_prompt,
device=device,
dtype=dtype,
max_sequence_length=77,
)
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
prompt_template=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_template is not None:
if not isinstance(prompt_template, dict):
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
if "template" not in prompt_template:
raise ValueError(
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
)
def prepare_latents(
self,
batch_size: int,
num_channels_latents: 32,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
num_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Union[str, List[str]] = None,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
num_inference_steps: int = 50,
sigmas: List[float] = None,
guidance_scale: float = 6.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
max_sequence_length: int = 256,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead.
height (`int`, defaults to `720`):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `129`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
not applied.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~HunyuanVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds,
callback_on_step_end_tensor_inputs,
prompt_template,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_template=prompt_template,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
device=device,
max_sequence_length=max_sequence_length,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
if pooled_prompt_embeds is not None:
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_latent_frames,
torch.float32,
device,
generator,
latents,
)
# 6. Prepare guidance condition
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = latents.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return HunyuanVideoPipelineOutput(frames=video)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
r"""
Output class for HunyuanVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -798,7 +798,11 @@ class HunyuanDiTPipeline(DiffusionPipeline):
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
)
style = torch.tensor([0], device=device)
+50
View File
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

Some files were not shown because too many files have changed in this diff Show More