Compare commits

..

63 Commits

Author SHA1 Message Date
sayakpaul 51d1436a16 faster model loading on cuda. 2025-04-21 18:05:39 +05:30
Kenneth Gerald Hamilton 0dec414d5b [train_dreambooth_lora_sdxl.py] Fix the LR Schedulers when num_train_epochs is passed in a distributed training env (#11240)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-04-21 12:51:03 +05:30
Linoy Tsaban 44eeba07b2 [Flux LoRAs] fix lr scheduler bug in distributed scenarios (#11242)
* add fix

* add fix

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-21 10:08:45 +03:00
YiYi Xu 5873377a66 [Wan2.1-FLF2V] update conversion script (#11365)
update scheuler config in conversion sript
2025-04-18 14:08:44 -10:00
YiYi Xu 5a2e0f715c update output for Hidream transformer (#11366)
up
2025-04-18 14:07:21 -10:00
Kazuki Yoda ef47726e2d Fix: StableDiffusionXLControlNetAdapterInpaintPipeline incorrectly inherited StableDiffusionLoraLoaderMixin (#11357)
Fix: Inherit `StableDiffusionXLLoraLoaderMixin`

`StableDiffusionXLControlNetAdapterInpaintPipeline`
used to incorrectly inherit
`StableDiffusionLoraLoaderMixin`
instead of `StableDiffusionXLLoraLoaderMixin`
2025-04-18 12:46:06 -10:00
YiYi Xu 0021bfa1e1 support Wan-FLF2V (#11353)
* update transformer

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-04-18 10:27:50 -10:00
Marc Sun bbd0c161b5 [BNB] Fix test_moving_to_cpu_throws_warning (#11356)
fix

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-18 09:44:51 +05:30
Yao Matrix eef3d65954 enable 2 test cases on XPU (#11332)
* enable 2 test cases on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Apply style fixes

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-04-17 13:27:41 -10:00
Frank (Haofan) Wang ee6ad51d96 Update controlnet_flux.py (#11350) 2025-04-17 10:05:01 -10:00
Sayak Paul 4397f59a37 [bitsandbytes] improve dtype mismatch handling for bnb + lora. (#11270)
* improve dtype mismatch handling for bnb + lora.

* add a test

* fix and updates

* update
2025-04-17 19:51:49 +05:30
YiYi Xu 056793295c [Hi Dream] follow-up (#11296)
* add
2025-04-17 01:17:44 -10:00
Sayak Paul 29d2afbfe2 [LoRA] Propagate hotswap better (#11333)
* propagate hotswap to other load_lora_weights() methods.

* simplify documentations.

* updates

* propagate to load_lora_into_text_encoder.

* empty commit
2025-04-17 10:35:38 +05:30
Sayak Paul b00a564dac [docs] add note about use_duck_shape in auraflow docs. (#11348)
add note about use_duck_shape in auraflow docs.
2025-04-17 10:25:39 +05:30
Sayak Paul efc9d68b15 [chore] fix lora docs utils (#11338)
fix lora docs utils
2025-04-17 09:25:53 +05:30
nPeppon 3e59d531d1 Fix wrong dtype argument name as torch_dtype (#11346) 2025-04-16 16:00:25 -04:00
Ishan Modi d63e6fccb1 [BUG] fixed _toctree.yml alphabetical ordering (#11277)
update
2025-04-16 09:04:22 -07:00
Dhruv Nair 59f1b7b1c8 Hunyuan I2V fast tests fix (#11341)
* update

* update
2025-04-16 18:40:33 +05:30
Sayak Paul ce1063acfa [docs] add a snippet for compilation in the auraflow docs. (#11327)
* add a snippet for compilation in the auraflow docs.

* include speedups.
2025-04-16 11:12:09 +05:30
Sayak Paul 7212f35de2 [single file] enable telemetry for single file loading when using GGUF. (#11284)
* enable telemetry for single file loading when using GGUF.

* quality
2025-04-16 08:33:52 +05:30
Sayak Paul 3252d7ad11 unpin torch versions for onnx Dockerfile (#11290)
unpin torch versions for onnx
2025-04-16 08:16:38 +05:30
Dhruv Nair b316104ddd Fix Hunyuan I2V for transformers>4.47.1 (#11293)
* update

* update
2025-04-16 07:53:32 +05:30
Álvaro Somoza d3b2699a7f another fix for FlowMatchLCMScheduler forgotten import (#11330)
fix
2025-04-15 13:53:16 -10:00
Sayak Paul 4b868f14c1 post release 0.33.0 (#11255)
* post release

* update

* fix deprecations

* remaining

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-15 06:50:08 -10:00
AstraliteHeart b6156aafe9 Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible (#11297)
* Update pe_selection_index_based_on_dim

* Make pe_selection_index_based_on_dim work with torh.compile

* Fix AuraFlowTransformer2DModel's dpcstring default values

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-15 19:30:30 +05:30
Sayak Paul 7ecfe29160 [docs] fix hidream docstrings. (#11325)
* fix hidream docstrings.

* fix

* empty commit
2025-04-15 16:26:21 +05:30
Yao Matrix 7edace9a05 fix CPU offloading related fail cases on XPU (#11288)
* fix CPU offloading related fail cases on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Apply style fixes

* trigger tests

* test_pipe_same_device_id_offload

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-04-15 09:06:56 +01:00
hlky 6e80d240d3 Fix vae.Decoder prev_output_channel (#11280) 2025-04-15 08:03:50 +01:00
Hameer Abbasi 9352a5ca56 [LoRA] Add LoRA support to AuraFlow (#10216)
* Add AuraFlowLoraLoaderMixin

* Add comments, remove qkv fusion

* Add Tests

* Add AuraFlowLoraLoaderMixin to documentation

* Add Suggested changes

* Change attention_kwargs->joint_attention_kwargs

* Rebasing derp.

* fix

* fix

* Quality fixes.

* make style

* `make fix-copies`

* `ruff check --fix`

* Attept 1 to fix tests.

* Attept 2 to fix tests.

* Attept 3 to fix tests.

* Address review comments.

* Rebasing derp.

* Get more tests passing by copying from Flux. Address review comments.

* `joint_attention_kwargs`->`attention_kwargs`

* Add `lora_scale` property for te LoRAs.

* Make test better.

* Remove useless property.

* Skip TE-only tests for AuraFlow.

* Support LoRA for non-CLIP TEs.

* Restore LoRA tests.

* Undo adding LoRA support for non-CLIP TEs.

* Undo support for TE in AuraFlow LoRA.

* `make fix-copies`

* Sync with upstream changes.

* Remove unneeded stuff.

* Mirror `Lumina2`.

* Skip for MPS.

* Address review comments.

* Remove duplicated code.

* Remove unnecessary code.

* Remove repeated docs.

* Propagate attention.

* Fix TE target modules.

* MPS fix for LoRA tests.

* Unrelated TE LoRA tests fix.

* Fix AuraFlow LoRA tests by applying to the right denoiser layers.

Co-authored-by: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>

* Apply style fixes

* empty commit

* Fix the repo consistency issues.

* Remove unrelated changes.

* Style.

* Fix `test_lora_fuse_nan`.

* fix quality issues.

* `pytest.xfail` -> `ValueError`.

* Add back `skip_mps`.

* Apply style fixes

* `make fix-copies`

---------

Co-authored-by: Warlord-K <warlordk28@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-15 10:41:28 +05:30
Sayak Paul cefa28f449 [docs] Promote AutoModel usage (#11300)
* docs: promote the usage of automodel.

* bitsandbytes

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-04-15 09:25:40 +05:30
Beinsezii 8819cda6c0 Add skrample section to community_projects.md (#11319)
Update community_projects.md

https://github.com/huggingface/diffusers/discussions/11158#discussioncomment-12681691
2025-04-14 12:12:59 -10:00
hlky dcf836cf47 Use float32 on mps or npu in transformer_hidream_image's rope (#11316) 2025-04-14 20:19:21 +01:00
Álvaro Somoza 1cb73cb19f import for FlowMatchLCMScheduler (#11318)
* add

* fix-copies
2025-04-14 06:28:57 -10:00
Linoy Tsaban ba6008abfe [HiDream] code example (#11317) 2025-04-14 16:19:30 +01:00
Sayak Paul a8f5134c11 [LoRA] support more SDXL loras. (#11292)
* support more SDXL loras.

* update

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-04-14 17:09:59 +05:30
Fanli Lin c7f2d239fe make KolorsPipelineFastTests::test_inference_batch_single_identical pass on XPU (#11313)
adjust diff
2025-04-14 11:02:02 +01:00
Yao Matrix fa1ac50a66 make test_stable_diffusion_karras_sigmas pass on XPU (#11310)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-14 08:15:38 +01:00
Yao Matrix aa541b9fab make KandinskyV22PipelineInpaintCombinedFastTests::test_float16_inference pass on XPU (#11308)
loose expected_max_diff from 5e-1 to 8e-1 to make
KandinskyV22PipelineInpaintCombinedFastTests::test_float16_inference
pass on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-14 07:49:20 +01:00
Ishan Modi f1f38ffbee [ControlNet] Adds controlnet for SanaTransformer (#11040)
* added controlnet for sana transformer

* improve code quality

* addressed PR comments

* bug fixes

* added test cases

* update

* added dummy objects

* addressed PR comments

* update

* Forcing update

* add to docs

* code quality

* addressed PR comments

* addressed PR comments

* update

* addressed PR comments

* added proper styling

* update

* Revert "added proper styling"

This reverts commit 344ee8a701.

* manually ordered

* Apply suggestions from code review

---------

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
2025-04-13 19:19:39 +05:30
Tuna Tuncer 36538e1135 Fix incorrect tile_latent_min_width calculations (#11305) 2025-04-13 18:21:50 +05:30
Aryan 97e0ef4db4 Hidream refactoring follow ups (#11299)
* HiDream Image

* update

* -einops

* py3.8

* fix -einops

* mixins, offload_seq, option_components

* docs

* Apply style fixes

* trigger tests

* Apply suggestions from code review

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* joint_attention_kwargs -> attention_kwargs, fixes

* fast tests

* -_init_weights

* style tests

* move reshape logic

* update slice 😴

* supports_dduf

* 🤷🏻‍♂️

* Update src/diffusers/models/transformers/transformer_hidream_image.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* address review comments

* update tests

* doc updates

* update

* Update src/diffusers/models/transformers/transformer_hidream_image.py

* Apply style fixes

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-13 09:45:19 +05:30
Adrien B ed41db8525 Update autoencoderkl_allegro.md (#11303)
Correction typo
2025-04-13 09:41:30 +05:30
Nikita Starodubcev ec0b2b3947 flow matching lcm scheduler (#11170)
* add flow matching lcm scheduler
* stochastic sampling
* upscaling for scale-wise generation

* Apply style fixes

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-04-12 11:14:57 -10:00
hlky 0ef29355c9 HiDream Image (#11231)
* HiDream Image


---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-04-11 06:31:34 -10:00
Tuna Tuncer bc261058ee Fix incorrect tile_latent_min_width calculation in AutoencoderKLMochi (#11294) 2025-04-11 19:11:08 +05:30
Sayak Paul 7054a34978 do not use DIFFUSERS_REQUEST_TIMEOUT for notification bot (#11273)
fix to a constant
2025-04-11 14:19:46 +05:30
Sayak Paul 511d738121 [CI] relax tolerance for unclip further (#11268)
relax tolerance for unclip further.
2025-04-11 14:06:52 +05:30
Sayak Paul ea5a6a8b7c [Tests] Cleanup lora tests utils (#11276)
* start cleaning up lora test utils for reusability

* update

* updates

* updates
2025-04-10 15:50:34 +05:30
hlky b8093e6665 Fix LTX 0.9.5 single file (#11271) 2025-04-10 07:06:13 +01:00
Yuqian Hong e121d0ef67 [BUG] Fix convert_vae_pt_to_diffusers bug (#11078)
* fix attention

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-10 06:59:45 +01:00
Yao Matrix 31c4f24fc1 make test_instant_style_multiple_masks pass on XPU (#11266)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-10 06:23:00 +01:00
xieofxie 0efdf411fb add onnxruntime-qnn & onnxruntime-cann (#11269)
Co-authored-by: hualxie <hualxie@microsoft.com>
2025-04-10 06:00:23 +01:00
Yao Matrix 450dc48a2c make test_dict_tuple_outputs_equivalent pass on XPU (#11265)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-10 05:56:28 +01:00
Yao Matrix 77b4f66b9e make test_stable_diffusion_inpaint_fp16 pass on XPU (#11264)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-10 05:55:54 +01:00
Yao Matrix 68663f8a17 fix test_vanilla_funetuning failure on XPU and A100 (#11263)
* fix test_vanilla_funetuning failure on XPU and A100

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* change back to 5e-2

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-10 05:55:07 +01:00
Sayak Paul ffda8735be [LoRA] support musubi wan loras. (#11243)
* support musubi wan loras.

* Update src/diffusers/loaders/lora_conversion_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* support i2v loras from musubi too.

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-04-10 09:50:22 +05:30
YiYi Xu 0706786e53 fix wan ftfy import (#11262) 2025-04-09 09:08:34 -10:00
Sayak Paul 5b27f8aba8 fix consisid imports (#11254)
* fix consisid imports

* fix opencv import

* fix
2025-04-09 18:49:32 +05:30
Sayak Paul d1387ecee5 fix timeout constant (#11252)
* fix timeout constant

* style

* fix
2025-04-09 17:48:52 +05:30
Ilya Drobyshevskiy 6a7c2d0afa fix flux controlnet bug (#11152)
Before this if txt_ids was 3d tensor, line with txt_ids[:1] concat txt_ids by batch dim. Now we first check that txt_ids is 2d tensor (or take first batch element) and then concat by token dim
2025-04-09 13:01:07 +01:00
Dhruv Nair edc154da09 Update Ruff to latest Version (#10919)
* update

* update

* update

* update
2025-04-09 16:51:34 +05:30
hlky 552cd32058 [docs] AutoModel (#11250)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-09 16:42:23 +05:30
Yao Matrix c36c745ceb fix FluxReduxSlowTests::test_flux_redux_inference case failure on XPU (#11245)
* loose test_float16_inference's tolerance from 5e-2 to 6e-2, so XPU can
pass UT

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix test_pipeline_flux_redux fail on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-04-09 11:41:15 +01:00
152 changed files with 7009 additions and 1157 deletions
+3 -3
View File
@@ -28,9 +28,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3 -m uv pip install --no-cache-dir \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
torch \
torchvision \
torchaudio\
onnxruntime \
--extra-index-url https://download.pytorch.org/whl/cpu && \
python3 -m uv pip install --no-cache-dir \
+43 -33
View File
@@ -175,7 +175,7 @@
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
@@ -265,19 +265,23 @@
sections:
- local: api/models/overview
title: Overview
- local: api/models/auto_model
title: AutoModel
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_union
title: ControlNetUnionModel
- local: api/models/controlnet_flux
title: FluxControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sana
title: SanaControlNetModel
- local: api/models/controlnet_sd3
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
- local: api/models/controlnet_union
title: ControlNetUnionModel
title: ControlNets
- sections:
- local: api/models/allegro_transformer3d
@@ -286,30 +290,32 @@
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/consisid_transformer3d
title: ConsisIDTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
- local: api/models/cogview4_transformer2d
title: CogView4Transformer2DModel
- local: api/models/consisid_transformer3d
title: ConsisIDTransformer3DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/lumina2_transformer2d
title: Lumina2Transformer2DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
title: Lumina2Transformer2DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/omnigen_transformer
@@ -318,10 +324,10 @@
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
@@ -336,10 +342,10 @@
title: StableCascadeUNet
- local: api/models/unet
title: UNet1DModel
- local: api/models/unet2d
title: UNet2DModel
- local: api/models/unet2d-cond
title: UNet2DConditionModel
- local: api/models/unet2d
title: UNet2DModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
- local: api/models/unet-motion
@@ -348,6 +354,10 @@
title: UViT2DModel
title: UNets
- sections:
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
title: AutoencoderDC
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/autoencoderkl_allegro
@@ -364,10 +374,6 @@
title: AutoencoderKLMochi
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
title: AutoencoderDC
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
@@ -420,6 +426,8 @@
title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnet_sana
title: ControlNet-Sana
- local: api/pipelines/controlnetxs
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
@@ -444,6 +452,8 @@
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
@@ -511,40 +521,40 @@
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
- local: api/pipelines/stable_diffusion/depth2img
title: Depth-to-image
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
- local: api/pipelines/stable_diffusion/image_variation
title: Image variation
- local: api/pipelines/stable_diffusion/img2img
title: Image-to-image
- local: api/pipelines/stable_diffusion/svd
title: Image-to-video
- local: api/pipelines/stable_diffusion/inpaint
title: Inpainting
- local: api/pipelines/stable_diffusion/depth2img
title: Depth-to-image
- local: api/pipelines/stable_diffusion/image_variation
title: Image variation
- local: api/pipelines/stable_diffusion/k_diffusion
title: K-Diffusion
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
title: Safe Stable Diffusion
- local: api/pipelines/stable_diffusion/sdxl_turbo
title: SDXL Turbo
- local: api/pipelines/stable_diffusion/stable_diffusion_2
title: Stable Diffusion 2
- local: api/pipelines/stable_diffusion/stable_diffusion_3
title: Stable Diffusion 3
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
title: Stable Diffusion XL
- local: api/pipelines/stable_diffusion/sdxl_turbo
title: SDXL Turbo
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/upscale
title: Super-resolution
- local: api/pipelines/stable_diffusion/k_diffusion
title: K-Diffusion
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
- local: api/pipelines/stable_diffusion/adapter
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
+14
View File
@@ -20,10 +20,13 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
@@ -56,6 +59,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## Mochi1LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
## LTXVideoLoraLoaderMixin
@@ -73,6 +79,14 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
## CogView4LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
## WanLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
+29
View File
@@ -0,0 +1,29 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoModel
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
```python
from diffusers import AutoModel, AutoPipelineForText2Image
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
```
## AutoModel
[[autodoc]] AutoModel
- all
- from_pretrained
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLAllegro
vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLAllegro
@@ -0,0 +1,29 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SanaControlNetModel
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
## SanaControlNetModel
[[autodoc]] SanaControlNetModel
## SanaControlNetOutput
[[autodoc]] models.controlnets.controlnet_sana.SanaControlNetOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HiDreamImageTransformer2DModel
A Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai).
The model can be loaded with the following code snippet.
```python
from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+17
View File
@@ -89,6 +89,23 @@ image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
## Support for `torch.compile()`
AuraFlow can be compiled with `torch.compile()` to speed up inference latency even for different resolutions. First, install PyTorch nightly following the instructions from [here](https://pytorch.org/). The snippet below shows the changes needed to enable this:
```diff
+ torch.fx.experimental._config.use_duck_shape = False
+ pipeline.transformer = torch.compile(
pipeline.transformer, fullgraph=True, dynamic=True
)
```
Specifying `use_duck_shape` to be `False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
This enables from 100% (on low resolutions) to a 30% (on 1536x1536 resolution) speed improvements.
Thanks to [AstraliteHeart](https://github.com/huggingface/diffusers/pull/11297/) who helped us rewrite the [`AuraFlowTransformer2DModel`] class so that the above works for different resolutions ([PR](https://github.com/huggingface/diffusers/pull/11297/)).
## AuraFlowPipeline
[[autodoc]] AuraFlowPipeline
@@ -0,0 +1,36 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
## SanaControlNetPipeline
[[autodoc]] SanaControlNetPipeline
- all
- __call__
## SanaPipelineOutput
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
+43
View File
@@ -0,0 +1,43 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HiDreamImage
[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Available models
The following models are available for the [`HiDreamImagePipeline`](text-to-image) pipeline:
| Model name | Description |
|:---|:---|
| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - |
| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - |
| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - |
## HiDreamImagePipeline
[[autodoc]] HiDreamImagePipeline
- all
- __call__
## HiDreamImagePipelineOutput
[[autodoc]] pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput
+54
View File
@@ -133,6 +133,60 @@ output = pipe(
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### First and Last Frame Interpolation
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
def center_crop_resize(image, height, width):
# Calculate resize ratio to match first frame dimensions
resize_ratio = max(width / image.width, height / image.height)
# Resize the image
width = round(image.width * resize_ratio)
height = round(image.height * resize_ratio)
size = [width, height]
image = TF.center_crop(image, size)
return image, height, width
first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
if last_frame.size != first_frame.size:
last_frame, _, _ = center_crop_resize(last_frame, height, width)
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipe(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
### Video to Video Generation
```python
+4
View File
@@ -83,4 +83,8 @@ Happy exploring, and thank you for being part of the Diffusers community!
<td><a href="https://github.com/suzukimain/auto_diffusers"> Model Search </a></td>
<td>Search models on Civitai and Hugging Face</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/beinsezii/skrample"> Skrample </a></td>
<td>Fully modular scheduler functions with 1st class diffusers integration.</td>
</tr>
</table>
+16 -16
View File
@@ -49,7 +49,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from diffusers import AutoModel
from transformers import T5EncoderModel
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
@@ -63,7 +63,7 @@ text_encoder_2_8bit = T5EncoderModel.from_pretrained(
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
transformer_8bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
@@ -74,7 +74,7 @@ transformer_8bit = FluxTransformer2DModel.from_pretrained(
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
```diff
transformer_8bit = FluxTransformer2DModel.from_pretrained(
transformer_8bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
@@ -133,7 +133,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from diffusers import AutoModel
from transformers import T5EncoderModel
quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True,)
@@ -147,7 +147,7 @@ text_encoder_2_4bit = T5EncoderModel.from_pretrained(
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True,)
transformer_4bit = FluxTransformer2DModel.from_pretrained(
transformer_4bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
@@ -158,7 +158,7 @@ transformer_4bit = FluxTransformer2DModel.from_pretrained(
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
```diff
transformer_4bit = FluxTransformer2DModel.from_pretrained(
transformer_4bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
@@ -217,11 +217,11 @@ print(model.get_memory_footprint())
Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:
```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
from diffusers import AutoModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
model_4bit = AutoModel.from_pretrained(
"hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```
@@ -243,13 +243,13 @@ An "outlier" is a hidden state value greater than a certain threshold, and these
To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:
```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
from diffusers import AutoModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_8bit=True, llm_int8_threshold=10,
)
model_8bit = FluxTransformer2DModel.from_pretrained(
model_8bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
@@ -305,7 +305,7 @@ NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from diffusers import AutoModel
from transformers import T5EncoderModel
quant_config = TransformersBitsAndBytesConfig(
@@ -325,7 +325,7 @@ quant_config = DiffusersBitsAndBytesConfig(
bnb_4bit_quant_type="nf4",
)
transformer_4bit = FluxTransformer2DModel.from_pretrained(
transformer_4bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
@@ -343,7 +343,7 @@ Nested quantization is a technique that can save additional memory at no additio
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from diffusers import AutoModel
from transformers import T5EncoderModel
quant_config = TransformersBitsAndBytesConfig(
@@ -363,7 +363,7 @@ quant_config = DiffusersBitsAndBytesConfig(
bnb_4bit_use_double_quant=True,
)
transformer_4bit = FluxTransformer2DModel.from_pretrained(
transformer_4bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
@@ -379,7 +379,7 @@ Once quantized, you can dequantize a model to its original precision, but this m
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from diffusers import AutoModel
from transformers import T5EncoderModel
quant_config = TransformersBitsAndBytesConfig(
@@ -399,7 +399,7 @@ quant_config = DiffusersBitsAndBytesConfig(
bnb_4bit_use_double_quant=True,
)
transformer_4bit = FluxTransformer2DModel.from_pretrained(
transformer_4bit = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
+12 -9
View File
@@ -26,13 +26,13 @@ The example below only quantizes the weights to int8.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
transformer = AutoModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
@@ -99,10 +99,10 @@ To serialize a quantized model in a given dtype, first load the model with the d
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from diffusers import AutoModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
@@ -115,9 +115,9 @@ To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] me
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers import FluxPipeline, AutoModel
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
transformer = AutoModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
@@ -131,10 +131,10 @@ If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, c
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
@@ -146,10 +146,13 @@ transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, m
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer = AutoModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
> [!TIP]
> The [`AutoModel`] API is supported for PyTorch >= 2.6 as shown in the examples below.
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
+3
View File
@@ -163,6 +163,9 @@ Models are initiated with the [`~ModelMixin.from_pretrained`] method which also
>>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
```
> [!TIP]
> Use the [`AutoModel`] API to automatically select a model class if you're unsure of which one to use.
To access the model parameters, call `model.config`:
```py
+2 -2
View File
@@ -31,10 +31,10 @@ To adapt your text-to-image model for inpainting, you'll need to change the numb
Initialize a [`UNet2DConditionModel`] with the pretrained text-to-image model weights, and change `in_channels` to 9. Changing the number of `in_channels` means you need to set `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=False` to avoid a size mismatch error because the shape is different now.
```py
from diffusers import UNet2DConditionModel
from diffusers import AutoModel
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(
unet = AutoModel.from_pretrained(
model_id,
subfolder="unet",
in_channels=9,
@@ -165,10 +165,10 @@ flush()
Load the diffusion transformer next which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
```py
from diffusers import FluxTransformer2DModel
from diffusers import AutoModel
import torch
transformer = FluxTransformer2DModel.from_pretrained(
transformer = AutoModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
device_map="auto",
@@ -32,9 +32,9 @@ The denoiser checkpoint can also have multiple shards and supports inference tha
For example, let's save a sharded checkpoint for the [SDXL UNet](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/unet):
```python
from diffusers import UNet2DConditionModel
from diffusers import AutoModel
unet = UNet2DConditionModel.from_pretrained(
unet = AutoModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
@@ -43,10 +43,10 @@ unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]:
```python
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
from diffusers import AutoModel, StableDiffusionXLPipeline
import torch
unet = UNet2DConditionModel.from_pretrained(
unet = AutoModel.from_pretrained(
"sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -134,7 +134,7 @@ The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads L
- the LoRA weights don't have separate identifiers for the UNet and text encoder
- the LoRA weights have separate identifiers for the UNet and text encoder
To directly load (and save) a LoRA adapter at the *model-level*, use [`~PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
To directly load (and save) a LoRA adapter at the *model-level*, use [`~loaders.PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`~loaders.PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
Use the `weight_name` parameter to specify the specific weight file and the `prefix` parameter to filter for the appropriate state dicts (`"unet"` in this case) to load.
@@ -155,7 +155,7 @@ image
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>
Save an adapter with [`~PeftAdapterMixin.save_lora_adapter`].
Save an adapter with [`~loaders.PeftAdapterMixin.save_lora_adapter`].
To unload the LoRA weights, use the [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
@@ -66,10 +66,10 @@ Let's dive deeper into what these steps entail.
1. Load a UNet that corresponds to the UNet in the LoRA checkpoint. In this case, both LoRAs use the SDXL UNet as their base model.
```python
from diffusers import UNet2DConditionModel
from diffusers import AutoModel
import torch
unet = UNet2DConditionModel.from_pretrained(
unet = AutoModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
use_safetensors=True,
@@ -136,7 +136,7 @@ feng_peft_model.load_state_dict(original_state_dict, strict=True)
```python
from peft import PeftModel
base_unet = UNet2DConditionModel.from_pretrained(
base_unet = AutoModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
use_safetensors=True,
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -1915,17 +1915,22 @@ def main(args):
free_memory()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1949,7 +1954,6 @@ def main(args):
lr_scheduler,
)
else:
print("I SHOULD BE HERE")
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
)
@@ -1961,8 +1965,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -80,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
class MarigoldDepthOutput(BaseOutput):
@@ -33,7 +33,6 @@ from diffusers import DiffusionPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
@@ -300,7 +299,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class StableDiffusionXLControlNetAdapterInpaintPipeline(
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = logging.getLogger(__name__)
+2 -2
View File
@@ -51,7 +51,7 @@ from diffusers import (
FlowMatchEulerDiscreteScheduler,
FluxTransformer2DModel,
)
from diffusers.models.controlnet_flux import FluxControlNetModel
from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+19 -8
View File
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -1407,17 +1407,22 @@ def main(args):
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1444,8 +1449,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+1 -1
View File
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -1524,17 +1524,22 @@ def main(args):
free_memory()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1561,8 +1566,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -48,7 +48,7 @@ import diffusers
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
Lumina2Text2ImgPipeline,
Lumina2Pipeline,
Lumina2Transformer2DModel,
)
from diffusers.optimization import get_scheduler
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -898,7 +898,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
pipeline = Lumina2Text2ImgPipeline.from_pretrained(
pipeline = Lumina2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
revision=args.revision,
@@ -990,7 +990,7 @@ def main(args):
text_encoder.to(dtype=torch.bfloat16)
# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = Lumina2Text2ImgPipeline.from_pretrained(
text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
@@ -1034,7 +1034,7 @@ def main(args):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Lumina2Text2ImgPipeline.save_lora_weights(
Lumina2Pipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
)
@@ -1050,7 +1050,7 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
lora_state_dict = Lumina2Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
@@ -1473,7 +1473,7 @@ def main(args):
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
pipeline = Lumina2Text2ImgPipeline.from_pretrained(
pipeline = Lumina2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
@@ -1503,14 +1503,14 @@ def main(args):
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
Lumina2Text2ImgPipeline.save_lora_weights(
Lumina2Pipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
)
# Final inference
# Load previous pipeline
pipeline = Lumina2Text2ImgPipeline.from_pretrained(
pipeline = Lumina2Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -1523,17 +1523,22 @@ def main(args):
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1550,7 +1555,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -54,7 +54,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.33.0.dev0")
check_min_version("0.34.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -0,0 +1,216 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from diffusers import (
SanaControlNetModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
def main(args):
file_path = args.orig_ckpt_path
all_state_dict = torch.load(file_path, weights_only=True)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
# ControlNet Input Projection.
converted_state_dict["input_block.weight"] = state_dict.pop("controlnet.0.before_proj.weight")
converted_state_dict["input_block.bias"] = state_dict.pop("controlnet.0.before_proj.bias")
for depth in range(7):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"controlnet.{depth}.copied_block.scale_shift_table"
)
# Linear Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"controlnet.{depth}.copied_block.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"controlnet.{depth}.copied_block.attn.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
f"controlnet.{depth}.copied_block.mlp.point_conv.conv.weight"
)
# Cross-attention.
q = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(
state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"controlnet.{depth}.copied_block.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"controlnet.{depth}.copied_block.cross_attn.proj.bias"
)
# ControlNet After Projection
converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(
f"controlnet.{depth}.after_proj.weight"
)
converted_state_dict[f"controlnet_blocks.{depth}.bias"] = state_dict.pop(f"controlnet.{depth}.after_proj.bias")
# ControlNet
with CTX():
controlnet = SanaControlNetModel(
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
sample_size=args.image_size // 32,
interpolation_scale=interpolation_scale[args.image_size],
)
if is_accelerate_available():
load_model_dict_into_meta(controlnet, converted_state_dict)
else:
controlnet.load_state_dict(converted_state_dict, strict=True, assign=True)
num_model_params = sum(p.numel() for p in controlnet.parameters())
print(f"Total number of controlnet parameters: {num_model_params}")
controlnet = controlnet.to(weight_dtype)
controlnet.load_state_dict(converted_state_dict, strict=True)
print(f"Saving Sana ControlNet in Diffusers format in {args.dump_path}.")
controlnet.save_pretrained(args.dump_path)
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[512, 1024, 2048, 4096],
required=False,
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type",
default="SanaMS_1600M_P1_ControlNet_D7",
type=str,
choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"],
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
model_kwargs = {
"SanaMS_1600M_P1_ControlNet_D7": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 7,
},
"SanaMS_600M_P1_ControlNet_D7": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 7,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)
+18 -2
View File
@@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
resnets = [
key
for key in down_blocks[i]
if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
]
attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
@@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
paths = renew_vae_attention_paths(attentions)
meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
@@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
key
for key in up_blocks[block_id]
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
]
attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
paths = renew_vae_attention_paths(attentions)
meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
+43 -2
View File
@@ -39,6 +39,24 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# for the FLF2V model
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
elif model_type == "Wan-FLF2V-14B-720P":
config = {
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": 257 * 2,
},
}
return config
@@ -393,11 +433,12 @@ if __name__ == "__main__":
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
if "I2V" in args.model_type:
if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
+2 -1
View File
@@ -142,6 +142,7 @@ _deps = [
"urllib3<=2.0.0",
"black",
"phonemizer",
"opencv-python",
]
# this is a lookup table with items like:
@@ -268,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.33.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.34.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+34 -3
View File
@@ -1,4 +1,4 @@
__version__ = "0.33.0.dev0"
__version__ = "0.34.0.dev0"
from typing import TYPE_CHECKING
@@ -14,6 +14,7 @@ from .utils import (
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
@@ -170,6 +171,7 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
@@ -188,6 +190,7 @@ else:
"OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaControlNetModel",
"SanaTransformer2DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
@@ -265,6 +268,7 @@ else:
"EulerDiscreteScheduler",
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"FlowMatchLCMScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
@@ -352,7 +356,6 @@ else:
"CogView3PlusPipeline",
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"CycleDiffusionPipeline",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
@@ -368,6 +371,7 @@ else:
"FluxInpaintPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
@@ -426,6 +430,7 @@ else:
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintPipeline",
@@ -518,6 +523,19 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
_import_structure["utils.dummy_torch_and_transformers_and_opencv_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_opencv_objects) if not name.startswith("_")
]
else:
_import_structure["pipelines"].extend(["ConsisIDPipeline"])
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
@@ -748,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
@@ -766,6 +785,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaControlNetModel,
SanaTransformer2DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
@@ -841,6 +861,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
FlowMatchLCMScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
@@ -909,7 +930,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView3PlusPipeline,
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
CycleDiffusionPipeline,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
@@ -925,6 +945,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxInpaintPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
@@ -983,6 +1004,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintPipeline,
@@ -1088,6 +1110,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
else:
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_opencv_objects import * # noqa F403
else:
from .pipelines import ConsisIDPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
@@ -49,4 +49,5 @@ deps = {
"urllib3": "urllib3<=2.0.0",
"black": "black",
"phonemizer": "phonemizer",
"opencv-python": "opencv-python",
}
+2
View File
@@ -65,6 +65,7 @@ if is_torch_available():
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -103,6 +104,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
+1 -1
View File
@@ -526,7 +526,7 @@ class FluxIPAdapterMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
dtype=image_encoder_dtype,
torch_dtype=image_encoder_dtype,
)
.to(self.device)
.eval()
+80 -1
View File
@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
# check if state_dict contains both patterns
contains_sgm_patterns = False
contains_not_sgm_patterns = False
for key in all_keys:
if any(p in key for p in sgm_patterns):
contains_sgm_patterns = True
elif any(p in key for p in not_sgm_patterns):
contains_not_sgm_patterns = True
# if state_dict contains both patterns, remove sgm
# we can then return state_dict immediately
if contains_sgm_patterns and contains_not_sgm_patterns:
for key in all_keys:
if any(p in key for p in sgm_patterns):
state_dict.pop(key)
return state_dict
# 2. check if needs remapping, if not return original dict
is_in_sgm_format = False
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
)
new_state_dict[new_key] = state_dict.pop(key)
if len(state_dict) > 0:
if state_dict:
raise ValueError("At this point all state dict entries have to be converted.")
return new_state_dict
@@ -1608,3 +1626,64 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_musubi_wan_lora_to_diffusers(state_dict):
# https://github.com/kohya-ss/musubi-tuner
converted_state_dict = {}
original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
def get_alpha_scales(down_weight, key):
rank = down_weight.shape[0]
alpha = original_state_dict.pop(key + ".alpha").item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for i in range(num_blocks):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
# FFN
for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
if len(original_state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
File diff suppressed because it is too large Load Diff
+1
View File
@@ -52,6 +52,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
@@ -21,6 +21,7 @@ import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
@@ -260,6 +261,11 @@ class FromOriginalModelMixin:
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if quantization_config is not None:
user_agent["quant"] = quantization_config.quant_method.value
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
@@ -278,6 +284,7 @@ class FromOriginalModelMixin:
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
user_agent=user_agent,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+37 -3
View File
@@ -177,6 +177,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -404,13 +405,16 @@ def load_single_file_checkpoint(
local_files_only=None,
revision=None,
disable_mmap=False,
user_agent=None,
):
if user_agent is None:
user_agent = {"file_type": "single_file", "framework": "pytorch"}
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
user_agent = {"file_type": "single_file", "framework": "pytorch"}
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
@@ -638,7 +642,9 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-schnell"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
model_type = "ltx-video-0.9.5"
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
model_type = "ltx-video-0.9.1"
else:
model_type = "ltx-video"
@@ -2403,13 +2409,41 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"last_scale_shift_table": "scale_shift_table",
}
VAE_095_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
for key in list(converted_state_dict.keys()):
+4
View File
@@ -49,6 +49,7 @@ if is_torch_available():
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DMultiControlNetModel",
]
_import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
@@ -76,6 +77,7 @@ if is_torch_available():
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
@@ -133,6 +135,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
MultiControlNetUnionModel,
SanaControlNetModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SparseControlNetModel,
@@ -151,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,
@@ -829,7 +829,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
@@ -1285,7 +1285,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
@@ -887,7 +887,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
return self.tiled_decode(z, return_dict=return_dict)
@@ -909,7 +909,7 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
+1 -1
View File
@@ -255,7 +255,7 @@ class Decoder(nn.Module):
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
prev_output_channel=prev_output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
@@ -9,6 +9,7 @@ if is_torch_available():
HunyuanDiT2DControlNetModel,
HunyuanDiT2DMultiControlNetModel,
)
from .controlnet_sana import SanaControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
from .controlnet_sparsectrl import (
SparseControlNetConditioningEmbedding,
@@ -298,15 +298,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
@@ -320,6 +311,15 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
img_ids = img_ids[0]
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
@@ -430,7 +430,7 @@ class FluxMultiControlNetModel(ModelMixin):
) -> Union[FluxControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1 and self.nets[0].union:
if len(self.nets) == 1:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
@@ -454,17 +454,18 @@ class FluxMultiControlNetModel(ModelMixin):
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
# Regular Multi-ControlNets
# load all ControlNets into memories
@@ -0,0 +1,290 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from ..transformers.sana_transformer import SanaTransformerBlock
from .controlnet import zero_module
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class SanaControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
@register_to_config
def __init__(
self,
in_channels: int = 32,
out_channels: Optional[int] = 32,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
num_layers: int = 7,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim
# 1. Patch Embedding
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
pos_embed_type="sincos" if interpolation_scale is not None else None,
)
# 2. Additional condition embeddings
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
SanaTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
for _ in range(len(self.transformer_blocks)):
controlnet_block = nn.Linear(inner_dim, inner_dim)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size, num_channels, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height, post_patch_width = height // p, width // p
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
# 2. Transformer blocks
block_res_samples = ()
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
)
block_res_samples = block_res_samples + (hidden_states,)
else:
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
)
block_res_samples = block_res_samples + (hidden_states,)
# 3. ControlNet blocks
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if not return_dict:
return (controlnet_block_res_samples,)
return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+2 -2
View File
@@ -38,7 +38,7 @@ if is_transformers_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def text_encoder_attn_modules(text_encoder):
def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules
def text_encoder_mlp_modules(text_encoder):
def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+47 -1
View File
@@ -18,7 +18,7 @@ import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +38,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerator_device,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -304,6 +305,51 @@ def load_model_dict_into_meta(
return offload_index, state_dict_index
# Taken from
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
def _expand_device_map(device_map, param_names):
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
# Remove disk, cpu and meta devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
}
if not len(accelerator_device_map):
return
total_byte_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name)
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = param.numel() * param.element_size()
total_byte_count[device] += param_byte_count
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, byte_count in total_byte_count.items():
if device.type == "cuda":
index = device.index if device.index is not None else torch.cuda.current_device()
device_memory = torch.cuda.mem_get_info(index)[0]
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
# Allocate memory
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
+25
View File
@@ -63,7 +63,9 @@ from ..utils.hub_utils import (
populate_model_card,
)
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
@@ -1374,6 +1376,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return super().float(*args)
# Taken from `transformers`.
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
def get_parameter_or_buffer(self, target: str):
"""
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a leaf
of the model.
"""
try:
return self.get_parameter(target)
except AttributeError:
pass
try:
return self.get_buffer(target)
except AttributeError:
pass
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
@classmethod
def _load_pretrained_model(
cls,
@@ -1410,6 +1430,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
assign_to_params_buffers = None
error_msgs = []
# Optionally, warmup cuda to load the weights much faster on devices
if device_map is not None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -21,6 +21,7 @@ if is_torch_available():
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
@@ -13,15 +13,15 @@
# limitations under the License.
from typing import Dict, Union
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -74,15 +74,23 @@ class AuraFlowPatchEmbed(nn.Module):
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
# because original input are in flattened format, we have to flatten this 2d grid as well.
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
# Calculate the top-left corner indices for the centered patch grid
starth = h_max // 2 - h_p // 2
endh = starth + h_p
startw = w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
return original_pe_indexes.flatten()
# Generate the row and column indices for the desired patch grid
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
# Create a 2D grid of indices
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
# Convert the 2D grid indices to flattened 1D indices
selected_indices = (row_indices * w_max + col_indices).flatten()
return selected_indices
def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
@@ -160,14 +168,20 @@ class AuraFlowSingleTransformerBlock(nn.Module):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(hidden_states=norm_hidden_states)
attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +237,15 @@ class AuraFlowJointTransformerBlock(nn.Module):
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +255,9 @@ class AuraFlowJointTransformerBlock(nn.Module):
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -254,7 +275,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -262,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
num_single_dit_layers (`int`, *optional*, defaults to 4):
num_single_dit_layers (`int`, *optional*, defaults to 32):
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
representations.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
out_channels (`int`, defaults to 16): Number of output channels.
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
out_channels (`int`, defaults to 4): Number of output channels.
pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
"""
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
@@ -449,8 +470,24 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +511,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
attention_kwargs=attention_kwargs,
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +531,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
)
else:
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
combined_hidden_states = block(
hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
)
hidden_states = combined_hidden_states[:, encoder_seq_len:]
@@ -512,6 +554,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
@@ -483,6 +483,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
@@ -546,7 +547,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
# 2. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
@@ -557,9 +558,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
post_patch_height,
post_patch_width,
)
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
else:
for block in self.transformer_blocks:
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
attention_mask,
@@ -569,6 +572,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
post_patch_height,
post_patch_width,
)
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
# 3. Normalization
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
@@ -0,0 +1,922 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HiDreamImageFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
class HiDreamImagePooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
return self.pooled_embedder(pooled_embed)
class HiDreamImageTimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
class HiDreamImageOutEmbed(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
hidden_states = self.linear(hidden_states)
return hidden_states
class HiDreamImagePatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
def forward(self, latent):
latent = self.proj(latent)
return latent
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
is_mps = pos.device.type == "mps"
is_npu = pos.device.type == "npu"
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
class HiDreamImageEmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
@maybe_allow_in_graph
class HiDreamAttention(Attention):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor=None,
out_dim: int = None,
single: bool = False,
):
super(Attention, self).__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
self.to_q = nn.Linear(query_dim, self.inner_dim)
self.to_k = nn.Linear(self.inner_dim, self.inner_dim)
self.to_v = nn.Linear(self.inner_dim, self.inner_dim)
self.to_out = nn.Linear(self.inner_dim, self.out_dim)
self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
if not single:
self.to_q_t = nn.Linear(query_dim, self.inner_dim)
self.to_k_t = nn.Linear(self.inner_dim, self.inner_dim)
self.to_v_t = nn.Linear(self.inner_dim, self.inner_dim)
self.to_out_t = nn.Linear(self.inner_dim, self.out_dim)
self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
self.set_processor(processor)
def forward(
self,
norm_hidden_states: torch.Tensor,
hidden_states_masks: torch.Tensor = None,
norm_encoder_hidden_states: torch.Tensor = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
return self.processor(
self,
hidden_states=norm_hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
class HiDreamAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn: HiDreamAttention,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
*args,
**kwargs,
) -> torch.Tensor:
dtype = hidden_states.dtype
batch_size = hidden_states.shape[0]
query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
value_i = attn.to_v(hidden_states)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if hidden_states_masks is not None:
key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
value_t = attn.to_v_t(encoder_hidden_states)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
query, key = apply_rope(query, key, image_rotary_emb)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = F.scaled_dot_product_attention(
query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = "softmax"
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# print(bsz, seq_len, h)
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
):
super().__init__()
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
self.experts = nn.ModuleList(
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
)
self.gate = MoEGate(
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
# y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
return expert_cache
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
# 1. Attention
self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor=HiDreamAttnProcessor(),
single=True,
)
# 3. Feed-forward
self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim=dim,
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
wtype = hidden_states.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
:, None
].chunk(6, dim=-1)
# 1. MM-Attention
norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_hidden_states,
hidden_states_masks,
image_rotary_emb=image_rotary_emb,
)
hidden_states = gate_msa_i * attn_output_i + hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
hidden_states = ff_output_i + hidden_states
return hidden_states
@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
# 1. Attention
self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor=HiDreamAttnProcessor(),
single=False,
)
# 3. Feed-forward
self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim=dim,
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
wtype = hidden_states.dtype
(
shift_msa_i,
scale_msa_i,
gate_msa_i,
shift_mlp_i,
scale_mlp_i,
gate_mlp_i,
shift_msa_t,
scale_msa_t,
gate_msa_t,
shift_mlp_t,
scale_mlp_t,
gate_mlp_t,
) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
# 1. MM-Attention
norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_hidden_states,
hidden_states_masks,
norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = gate_msa_i * attn_output_i + hidden_states
encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
hidden_states = ff_output_i + hidden_states
encoder_hidden_states = ff_output_t + encoder_hidden_states
return hidden_states, encoder_hidden_states
class HiDreamBlock(nn.Module):
def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
super().__init__()
self.block = block
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
return self.block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
self.x_embedder = HiDreamImagePatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
out_channels=self.inner_dim,
)
self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamBlock(
HiDreamImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
)
for _ in range(num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamBlock(
HiDreamImageSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
)
for _ in range(num_single_layers)
]
)
self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
self.gradient_checkpointing = False
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
B, S, F = x.shape
C = F // (self.config.patch_size * self.config.patch_size)
x = (
x.reshape(B, S, self.config.patch_size, self.config.patch_size, C)
.permute(0, 4, 1, 2, 3)
.reshape(B, C, S, self.config.patch_size * self.config.patch_size)
)
else:
x_arr = []
p1 = self.config.patch_size
p2 = self.config.patch_size
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
t = x[i, : pH * pW].reshape(1, pH, pW, -1)
F_token = t.shape[-1]
C = F_token // (p1 * p2)
t = t.reshape(1, pH, pW, p1, p2, C)
t = t.permute(0, 5, 1, 3, 2, 4)
t = t.reshape(1, C, pH * p1, pW * p2)
x_arr.append(t)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, hidden_states):
batch_size, channels, height, width = hidden_states.shape
patch_size = self.config.patch_size
patch_height, patch_width = height // patch_size, width // patch_size
device = hidden_states.device
dtype = hidden_states.dtype
# create img_sizes
img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
# create hidden_states_masks
if hidden_states.shape[-2] != hidden_states.shape[-1]:
hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
hidden_states_masks[:, : patch_height * patch_width] = 1.0
else:
hidden_states_masks = None
# create img_ids
img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
row_indices = torch.arange(patch_height, device=device)[:, None]
col_indices = torch.arange(patch_width, device=device)[None, :]
img_ids[..., 1] = img_ids[..., 1] + row_indices
img_ids[..., 2] = img_ids[..., 2] + col_indices
img_ids = img_ids.reshape(patch_height * patch_width, -1)
if hidden_states.shape[-2] != hidden_states.shape[-1]:
# Handle non-square latents
img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
img_ids_pad[: patch_height * patch_width, :] = img_ids
img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
else:
img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
# patchify hidden_states
if hidden_states.shape[-2] != hidden_states.shape[-1]:
# Handle non-square latents
out = torch.zeros(
(batch_size, channels, self.max_seq, patch_size * patch_size),
dtype=dtype,
device=device,
)
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height * patch_width, patch_size * patch_size
)
out[:, :, 0 : patch_height * patch_width] = hidden_states
hidden_states = out
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch_size, self.max_seq, patch_size * patch_size * channels
)
else:
# Handle square latents
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
hidden_states = hidden_states.reshape(
batch_size, patch_height * patch_width, patch_size * patch_size * channels
)
return hidden_states, hidden_states_masks, img_sizes, img_ids
def forward(
self,
hidden_states: torch.Tensor,
timesteps: torch.LongTensor = None,
encoder_hidden_states_t5: torch.Tensor = None,
encoder_hidden_states_llama3: torch.Tensor = None,
pooled_embeds: torch.Tensor = None,
img_ids: Optional[torch.Tensor] = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
hidden_states_masks: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
):
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
encoder_hidden_states_t5 = encoder_hidden_states[0]
encoder_hidden_states_llama3 = encoder_hidden_states[1]
if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
deprecation_message = (
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
)
deprecate("img_ids", "0.34.0", deprecation_message)
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
elif hidden_states_masks is not None and hidden_states.ndim != 3:
raise ValueError(
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
# Patchify the input
if hidden_states_masks is None:
hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
# Embed the hidden states
hidden_states = self.x_embedder(hidden_states)
# 0. time
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder
encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(encoder_hidden_states_t5)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1]
+ encoder_hidden_states[-2].shape[1]
+ encoder_hidden_states[0].shape[1],
3,
device=img_ids.device,
dtype=img_ids.dtype,
)
ids = torch.cat((img_ids, txt_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat(
[initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
hidden_states_masks,
cur_encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states, initial_encoder_hidden_states = block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=cur_encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if hidden_states_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=hidden_states_masks.device,
dtype=hidden_states_masks.dtype,
)
hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
hidden_states_masks,
None,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=None,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, temb)
output = self.unpatchify(output, img_sizes, self.training)
if hidden_states_masks is not None:
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -49,8 +49,10 @@ class WanAttnProcessor2_0:
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
encoder_hidden_states_img = encoder_hidden_states[:, :257]
encoder_hidden_states = encoder_hidden_states[:, 257:]
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -108,14 +110,23 @@ class WanAttnProcessor2_0:
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
@@ -130,6 +141,7 @@ class WanTimeTextImageEmbedding(nn.Module):
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
@@ -141,7 +153,7 @@ class WanTimeTextImageEmbedding(nn.Module):
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
@@ -350,6 +362,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()
@@ -368,6 +381,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
+25 -4
View File
@@ -10,6 +10,7 @@ from ..utils import (
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_sentencepiece_available,
is_torch_available,
is_torch_npu_available,
@@ -155,7 +156,6 @@ else:
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -221,6 +221,7 @@ else:
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
@@ -280,7 +281,7 @@ else:
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"]
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline", "SanaControlNetPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
@@ -414,6 +415,18 @@ else:
"KolorsImg2ImgPipeline",
]
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import (
dummy_torch_and_transformers_and_opencv_objects,
)
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
else:
_import_structure["consisid"] = ["ConsisIDPipeline"]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
@@ -512,7 +525,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .consisid import ConsisIDPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
@@ -574,6 +586,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hidream_image import HiDreamImagePipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
@@ -651,7 +664,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .sana import SanaPipeline, SanaSprintPipeline
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -761,6 +774,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KolorsPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_opencv_objects import *
else:
from .consisid import ConsisIDPipeline
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
@@ -12,17 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5Tokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -112,7 +120,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class AuraFlowPipeline(DiffusionPipeline):
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -233,6 +241,7 @@ class AuraFlowPipeline(DiffusionPipeline):
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -259,10 +268,20 @@ class AuraFlowPipeline(DiffusionPipeline):
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -346,6 +365,11 @@ class AuraFlowPipeline(DiffusionPipeline):
negative_prompt_embeds = None
negative_prompt_attention_mask = None
if self.text_encoder is not None:
if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -403,6 +427,10 @@ class AuraFlowPipeline(DiffusionPipeline):
def guidance_scale(self):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@@ -428,6 +456,7 @@ class AuraFlowPipeline(DiffusionPipeline):
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -486,6 +515,10 @@ class AuraFlowPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -520,6 +553,7 @@ class AuraFlowPipeline(DiffusionPipeline):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
@@ -530,6 +564,7 @@ class AuraFlowPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -553,6 +588,7 @@ class AuraFlowPipeline(DiffusionPipeline):
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -594,6 +630,7 @@ class AuraFlowPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]
# perform guidance

Some files were not shown because too many files have changed in this diff Show More