Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b365801c57 | |||
| 644147a198 | |||
| c852f239f2 | |||
| be861e236f | |||
| 2d744f0707 | |||
| 41c7e72d44 |
@@ -28,9 +28,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio\
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
onnxruntime \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
|
||||
+33
-43
@@ -175,7 +175,7 @@
|
||||
title: gguf
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
- local: quantization/quanto
|
||||
- local: quantization/quanto
|
||||
title: quanto
|
||||
title: Quantization Methods
|
||||
- sections:
|
||||
@@ -265,23 +265,19 @@
|
||||
sections:
|
||||
- local: api/models/overview
|
||||
title: Overview
|
||||
- local: api/models/auto_model
|
||||
title: AutoModel
|
||||
- sections:
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_union
|
||||
title: ControlNetUnionModel
|
||||
- local: api/models/controlnet_flux
|
||||
title: FluxControlNetModel
|
||||
- local: api/models/controlnet_hunyuandit
|
||||
title: HunyuanDiT2DControlNetModel
|
||||
- local: api/models/controlnet_sana
|
||||
title: SanaControlNetModel
|
||||
- local: api/models/controlnet_sd3
|
||||
title: SD3ControlNetModel
|
||||
- local: api/models/controlnet_sparsectrl
|
||||
title: SparseControlNetModel
|
||||
- local: api/models/controlnet_union
|
||||
title: ControlNetUnionModel
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/allegro_transformer3d
|
||||
@@ -290,32 +286,30 @@
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/consisid_transformer3d
|
||||
title: ConsisIDTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
title: CogView3PlusTransformer2DModel
|
||||
- local: api/models/cogview4_transformer2d
|
||||
title: CogView4Transformer2DModel
|
||||
- local: api/models/consisid_transformer3d
|
||||
title: ConsisIDTransformer3DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/hunyuan_video_transformer_3d
|
||||
title: HunyuanVideoTransformer3DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/lumina2_transformer2d
|
||||
title: Lumina2Transformer2DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/lumina2_transformer2d
|
||||
title: Lumina2Transformer2DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/mochi_transformer3d
|
||||
title: MochiTransformer3DModel
|
||||
- local: api/models/omnigen_transformer
|
||||
@@ -324,10 +318,10 @@
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/stable_audio_transformer
|
||||
title: StableAudioDiTModel
|
||||
- local: api/models/transformer2d
|
||||
@@ -342,10 +336,10 @@
|
||||
title: StableCascadeUNet
|
||||
- local: api/models/unet
|
||||
title: UNet1DModel
|
||||
- local: api/models/unet2d-cond
|
||||
title: UNet2DConditionModel
|
||||
- local: api/models/unet2d
|
||||
title: UNet2DModel
|
||||
- local: api/models/unet2d-cond
|
||||
title: UNet2DConditionModel
|
||||
- local: api/models/unet3d-cond
|
||||
title: UNet3DConditionModel
|
||||
- local: api/models/unet-motion
|
||||
@@ -354,10 +348,6 @@
|
||||
title: UViT2DModel
|
||||
title: UNets
|
||||
- sections:
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
title: AutoencoderDC
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/autoencoderkl_allegro
|
||||
@@ -374,6 +364,10 @@
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
title: AutoencoderDC
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
@@ -426,8 +420,6 @@
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_sana
|
||||
title: ControlNet-Sana
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
@@ -452,8 +444,6 @@
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
@@ -521,40 +511,40 @@
|
||||
- sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
- local: api/pipelines/stable_diffusion/img2img
|
||||
title: Image-to-image
|
||||
- local: api/pipelines/stable_diffusion/svd
|
||||
title: Image-to-video
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
title: Stable Diffusion
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
|
||||
@@ -20,13 +20,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
|
||||
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
|
||||
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
|
||||
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
|
||||
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
|
||||
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
|
||||
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
|
||||
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
|
||||
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
@@ -59,9 +56,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
## Mochi1LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
|
||||
## AuraFlowLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
|
||||
|
||||
## LTXVideoLoraLoaderMixin
|
||||
|
||||
@@ -79,14 +73,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
|
||||
|
||||
## CogView4LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
|
||||
|
||||
## WanLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
|
||||
|
||||
## AmusedLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoModel
|
||||
|
||||
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel, AutoPipelineForText2Image
|
||||
|
||||
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
|
||||
```
|
||||
|
||||
|
||||
## AutoModel
|
||||
|
||||
[[autodoc]] AutoModel
|
||||
- all
|
||||
- from_pretrained
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLAllegro
|
||||
|
||||
vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLAllegro
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# SanaControlNetModel
|
||||
|
||||
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
|
||||
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
|
||||
|
||||
## SanaControlNetModel
|
||||
[[autodoc]] SanaControlNetModel
|
||||
|
||||
## SanaControlNetOutput
|
||||
[[autodoc]] models.controlnets.controlnet_sana.SanaControlNetOutput
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HiDreamImageTransformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai).
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import HiDreamImageTransformer2DModel
|
||||
|
||||
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HiDreamImageTransformer2DModel
|
||||
|
||||
[[autodoc]] HiDreamImageTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -89,23 +89,6 @@ image = pipeline(prompt).images[0]
|
||||
image.save("auraflow.png")
|
||||
```
|
||||
|
||||
## Support for `torch.compile()`
|
||||
|
||||
AuraFlow can be compiled with `torch.compile()` to speed up inference latency even for different resolutions. First, install PyTorch nightly following the instructions from [here](https://pytorch.org/). The snippet below shows the changes needed to enable this:
|
||||
|
||||
```diff
|
||||
+ torch.fx.experimental._config.use_duck_shape = False
|
||||
+ pipeline.transformer = torch.compile(
|
||||
pipeline.transformer, fullgraph=True, dynamic=True
|
||||
)
|
||||
```
|
||||
|
||||
Specifying `use_duck_shape` to be `False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
|
||||
This enables from 100% (on low resolutions) to a 30% (on 1536x1536 resolution) speed improvements.
|
||||
|
||||
Thanks to [AstraliteHeart](https://github.com/huggingface/diffusers/pull/11297/) who helped us rewrite the [`AuraFlowTransformer2DModel`] class so that the above works for different resolutions ([PR](https://github.com/huggingface/diffusers/pull/11297/)).
|
||||
|
||||
## AuraFlowPipeline
|
||||
|
||||
[[autodoc]] AuraFlowPipeline
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
|
||||
|
||||
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
|
||||
The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
|
||||
|
||||
## SanaControlNetPipeline
|
||||
[[autodoc]] SanaControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## SanaPipelineOutput
|
||||
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
|
||||
@@ -1,43 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# HiDreamImage
|
||||
|
||||
[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Available models
|
||||
|
||||
The following models are available for the [`HiDreamImagePipeline`](text-to-image) pipeline:
|
||||
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - |
|
||||
| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - |
|
||||
| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - |
|
||||
|
||||
## HiDreamImagePipeline
|
||||
|
||||
[[autodoc]] HiDreamImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HiDreamImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput
|
||||
@@ -83,8 +83,4 @@ Happy exploring, and thank you for being part of the Diffusers community!
|
||||
<td><a href="https://github.com/suzukimain/auto_diffusers"> Model Search </a></td>
|
||||
<td>Search models on Civitai and Hugging Face</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/beinsezii/skrample"> Skrample </a></td>
|
||||
<td>Fully modular scheduler functions with 1st class diffusers integration.</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
@@ -49,7 +49,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
from diffusers import AutoModel
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
|
||||
@@ -63,7 +63,7 @@ text_encoder_2_8bit = T5EncoderModel.from_pretrained(
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
|
||||
|
||||
transformer_8bit = AutoModel.from_pretrained(
|
||||
transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
@@ -74,7 +74,7 @@ transformer_8bit = AutoModel.from_pretrained(
|
||||
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
|
||||
|
||||
```diff
|
||||
transformer_8bit = AutoModel.from_pretrained(
|
||||
transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
@@ -133,7 +133,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
from diffusers import AutoModel
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True,)
|
||||
@@ -147,7 +147,7 @@ text_encoder_2_4bit = T5EncoderModel.from_pretrained(
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True,)
|
||||
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
@@ -158,7 +158,7 @@ transformer_4bit = AutoModel.from_pretrained(
|
||||
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
|
||||
|
||||
```diff
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
@@ -217,11 +217,11 @@ print(model.get_memory_footprint())
|
||||
Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel, BitsAndBytesConfig
|
||||
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
|
||||
model_4bit = AutoModel.from_pretrained(
|
||||
model_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
|
||||
)
|
||||
```
|
||||
@@ -243,13 +243,13 @@ An "outlier" is a hidden state value greater than a certain threshold, and these
|
||||
To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel, BitsAndBytesConfig
|
||||
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True, llm_int8_threshold=10,
|
||||
)
|
||||
|
||||
model_8bit = AutoModel.from_pretrained(
|
||||
model_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
@@ -305,7 +305,7 @@ NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
from diffusers import AutoModel
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(
|
||||
@@ -325,7 +325,7 @@ quant_config = DiffusersBitsAndBytesConfig(
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
@@ -343,7 +343,7 @@ Nested quantization is a technique that can save additional memory at no additio
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
from diffusers import AutoModel
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(
|
||||
@@ -363,7 +363,7 @@ quant_config = DiffusersBitsAndBytesConfig(
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
@@ -379,7 +379,7 @@ Once quantized, you can dequantize a model to its original precision, but this m
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
from diffusers import AutoModel
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(
|
||||
@@ -399,7 +399,7 @@ quant_config = DiffusersBitsAndBytesConfig(
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
|
||||
@@ -26,13 +26,13 @@ The example below only quantizes the weights to int8.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = AutoModel.from_pretrained(
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
@@ -99,10 +99,10 @@ To serialize a quantized model in a given dtype, first load the model with the d
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, TorchAoConfig
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = AutoModel.from_pretrained(
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
@@ -115,9 +115,9 @@ To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] me
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, AutoModel
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel
|
||||
|
||||
transformer = AutoModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
|
||||
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
@@ -131,10 +131,10 @@ If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, c
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = AutoModel.from_pretrained(
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
@@ -146,13 +146,10 @@ transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, m
|
||||
# Load the model
|
||||
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
|
||||
with init_empty_weights():
|
||||
transformer = AutoModel.from_config("/path/to/flux_uint4wo/config.json")
|
||||
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The [`AutoModel`] API is supported for PyTorch >= 2.6 as shown in the examples below.
|
||||
|
||||
## Resources
|
||||
|
||||
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
|
||||
|
||||
@@ -163,9 +163,6 @@ Models are initiated with the [`~ModelMixin.from_pretrained`] method which also
|
||||
>>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Use the [`AutoModel`] API to automatically select a model class if you're unsure of which one to use.
|
||||
|
||||
To access the model parameters, call `model.config`:
|
||||
|
||||
```py
|
||||
|
||||
@@ -31,10 +31,10 @@ To adapt your text-to-image model for inpainting, you'll need to change the numb
|
||||
Initialize a [`UNet2DConditionModel`] with the pretrained text-to-image model weights, and change `in_channels` to 9. Changing the number of `in_channels` means you need to set `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=False` to avoid a size mismatch error because the shape is different now.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
unet = AutoModel.from_pretrained(
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
in_channels=9,
|
||||
|
||||
@@ -165,10 +165,10 @@ flush()
|
||||
Load the diffusion transformer next which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
from diffusers import FluxTransformer2DModel
|
||||
import torch
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
device_map="auto",
|
||||
|
||||
@@ -32,9 +32,9 @@ The denoiser checkpoint can also have multiple shards and supports inference tha
|
||||
For example, let's save a sharded checkpoint for the [SDXL UNet](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/unet):
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
|
||||
)
|
||||
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
|
||||
@@ -43,10 +43,10 @@ unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
|
||||
The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]:
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel, StableDiffusionXLPipeline
|
||||
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
|
||||
@@ -134,7 +134,7 @@ The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads L
|
||||
- the LoRA weights don't have separate identifiers for the UNet and text encoder
|
||||
- the LoRA weights have separate identifiers for the UNet and text encoder
|
||||
|
||||
To directly load (and save) a LoRA adapter at the *model-level*, use [`~loaders.PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`~loaders.PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
|
||||
To directly load (and save) a LoRA adapter at the *model-level*, use [`~PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
|
||||
|
||||
Use the `weight_name` parameter to specify the specific weight file and the `prefix` parameter to filter for the appropriate state dicts (`"unet"` in this case) to load.
|
||||
|
||||
@@ -155,7 +155,7 @@ image
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
|
||||
</div>
|
||||
|
||||
Save an adapter with [`~loaders.PeftAdapterMixin.save_lora_adapter`].
|
||||
Save an adapter with [`~PeftAdapterMixin.save_lora_adapter`].
|
||||
|
||||
To unload the LoRA weights, use the [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
|
||||
|
||||
|
||||
@@ -66,10 +66,10 @@ Let's dive deeper into what these steps entail.
|
||||
1. Load a UNet that corresponds to the UNet in the LoRA checkpoint. In this case, both LoRAs use the SDXL UNet as their base model.
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel
|
||||
from diffusers import UNet2DConditionModel
|
||||
import torch
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
@@ -136,7 +136,7 @@ feng_peft_model.load_state_dict(original_state_dict, strict=True)
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
base_unet = AutoModel.from_pretrained(
|
||||
base_unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ from diffusers import (
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel
|
||||
from diffusers.models.controlnet_flux import FluxControlNetModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
Lumina2Transformer2DModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -898,7 +898,7 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
pipeline = Lumina2Pipeline.from_pretrained(
|
||||
pipeline = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
revision=args.revision,
|
||||
@@ -990,7 +990,7 @@ def main(args):
|
||||
text_encoder.to(dtype=torch.bfloat16)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
|
||||
text_encoding_pipeline = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
@@ -1034,7 +1034,7 @@ def main(args):
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
Lumina2Pipeline.save_lora_weights(
|
||||
Lumina2Text2ImgPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
)
|
||||
@@ -1050,7 +1050,7 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict = Lumina2Pipeline.lora_state_dict(input_dir)
|
||||
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
@@ -1473,7 +1473,7 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
# create pipeline
|
||||
pipeline = Lumina2Pipeline.from_pretrained(
|
||||
pipeline = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
transformer=accelerator.unwrap_model(transformer),
|
||||
revision=args.revision,
|
||||
@@ -1503,14 +1503,14 @@ def main(args):
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
Lumina2Pipeline.save_lora_weights(
|
||||
Lumina2Text2ImgPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
pipeline = Lumina2Pipeline.from_pretrained(
|
||||
pipeline = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
|
||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.33.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,216 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import (
|
||||
SanaControlNetModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
|
||||
def main(args):
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
all_state_dict = torch.load(file_path, weights_only=True)
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
# AdaLN-single LN
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
|
||||
|
||||
# ControlNet Input Projection.
|
||||
converted_state_dict["input_block.weight"] = state_dict.pop("controlnet.0.before_proj.weight")
|
||||
converted_state_dict["input_block.bias"] = state_dict.pop("controlnet.0.before_proj.bias")
|
||||
|
||||
for depth in range(7):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.scale_shift_table"
|
||||
)
|
||||
|
||||
# Linear Attention is all you need 🤘
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.mlp.point_conv.conv.weight"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(
|
||||
state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"controlnet.{depth}.copied_block.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# ControlNet After Projection
|
||||
converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(
|
||||
f"controlnet.{depth}.after_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"controlnet_blocks.{depth}.bias"] = state_dict.pop(f"controlnet.{depth}.after_proj.bias")
|
||||
|
||||
# ControlNet
|
||||
with CTX():
|
||||
controlnet = SanaControlNetModel(
|
||||
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
|
||||
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
|
||||
num_layers=model_kwargs[args.model_type]["num_layers"],
|
||||
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
caption_channels=2304,
|
||||
sample_size=args.image_size // 32,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(controlnet, converted_state_dict)
|
||||
else:
|
||||
controlnet.load_state_dict(converted_state_dict, strict=True, assign=True)
|
||||
|
||||
num_model_params = sum(p.numel() for p in controlnet.parameters())
|
||||
print(f"Total number of controlnet parameters: {num_model_params}")
|
||||
|
||||
controlnet = controlnet.to(weight_dtype)
|
||||
controlnet.load_state_dict(converted_state_dict, strict=True)
|
||||
|
||||
print(f"Saving Sana ControlNet in Diffusers format in {args.dump_path}.")
|
||||
controlnet.save_pretrained(args.dump_path)
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[512, 1024, 2048, 4096],
|
||||
required=False,
|
||||
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="SanaMS_1600M_P1_ControlNet_D7",
|
||||
type=str,
|
||||
choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"],
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_kwargs = {
|
||||
"SanaMS_1600M_P1_ControlNet_D7": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 7,
|
||||
},
|
||||
"SanaMS_600M_P1_ControlNet_D7": {
|
||||
"num_attention_heads": 36,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 16,
|
||||
"cross_attention_head_dim": 72,
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 7,
|
||||
},
|
||||
}
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
@@ -53,12 +53,7 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [
|
||||
key
|
||||
for key in down_blocks[i]
|
||||
if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
|
||||
]
|
||||
attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||
@@ -72,10 +67,6 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
paths = renew_vae_attention_paths(attentions)
|
||||
meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
@@ -94,11 +85,8 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key
|
||||
for key in up_blocks[block_id]
|
||||
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
@@ -112,10 +100,6 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
paths = renew_vae_attention_paths(attentions)
|
||||
meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
|
||||
@@ -142,7 +142,6 @@ _deps = [
|
||||
"urllib3<=2.0.0",
|
||||
"black",
|
||||
"phonemizer",
|
||||
"opencv-python",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
@@ -269,7 +268,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.34.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.33.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.34.0.dev0"
|
||||
__version__ = "0.33.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -14,7 +14,6 @@ from .utils import (
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
@@ -171,7 +170,6 @@ else:
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
@@ -190,7 +188,6 @@ else:
|
||||
"OmniGenTransformer2DModel",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SD3ControlNetModel",
|
||||
"SD3MultiControlNetModel",
|
||||
@@ -268,7 +265,6 @@ else:
|
||||
"EulerDiscreteScheduler",
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
"FlowMatchHeunDiscreteScheduler",
|
||||
"FlowMatchLCMScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
@@ -356,6 +352,7 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
@@ -371,7 +368,6 @@ else:
|
||||
"FluxInpaintPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
@@ -430,7 +426,6 @@ else:
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintPipeline",
|
||||
@@ -523,19 +518,6 @@ else:
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_opencv_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_opencv_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["ConsisIDPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -766,7 +748,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
@@ -785,7 +766,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
@@ -861,7 +841,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EulerDiscreteScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FlowMatchHeunDiscreteScheduler,
|
||||
FlowMatchLCMScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
@@ -930,6 +909,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
@@ -945,7 +925,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxInpaintPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
@@ -1004,7 +983,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaControlNetPipeline,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintPipeline,
|
||||
@@ -1110,15 +1088,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_opencv_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import ConsisIDPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -49,5 +49,4 @@ deps = {
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
"black": "black",
|
||||
"phonemizer": "phonemizer",
|
||||
"opencv-python": "opencv-python",
|
||||
}
|
||||
|
||||
@@ -65,7 +65,6 @@ if is_torch_available():
|
||||
"AmusedLoraLoaderMixin",
|
||||
"StableDiffusionLoraLoaderMixin",
|
||||
"SD3LoraLoaderMixin",
|
||||
"AuraFlowLoraLoaderMixin",
|
||||
"StableDiffusionXLLoraLoaderMixin",
|
||||
"LTXVideoLoraLoaderMixin",
|
||||
"LoraLoaderMixin",
|
||||
@@ -104,7 +103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
|
||||
@@ -526,7 +526,7 @@ class FluxIPAdapterMixin:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=image_encoder_dtype,
|
||||
dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device)
|
||||
.eval()
|
||||
|
||||
@@ -33,24 +33,6 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
|
||||
# 1. get all state_dict_keys
|
||||
all_keys = list(state_dict.keys())
|
||||
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
|
||||
not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
|
||||
|
||||
# check if state_dict contains both patterns
|
||||
contains_sgm_patterns = False
|
||||
contains_not_sgm_patterns = False
|
||||
for key in all_keys:
|
||||
if any(p in key for p in sgm_patterns):
|
||||
contains_sgm_patterns = True
|
||||
elif any(p in key for p in not_sgm_patterns):
|
||||
contains_not_sgm_patterns = True
|
||||
|
||||
# if state_dict contains both patterns, remove sgm
|
||||
# we can then return state_dict immediately
|
||||
if contains_sgm_patterns and contains_not_sgm_patterns:
|
||||
for key in all_keys:
|
||||
if any(p in key for p in sgm_patterns):
|
||||
state_dict.pop(key)
|
||||
return state_dict
|
||||
|
||||
# 2. check if needs remapping, if not return original dict
|
||||
is_in_sgm_format = False
|
||||
@@ -144,7 +126,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
|
||||
)
|
||||
new_state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
if state_dict:
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError("At this point all state dict entries have to be converted.")
|
||||
|
||||
return new_state_dict
|
||||
@@ -1626,64 +1608,3 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_musubi_wan_lora_to_diffusers(state_dict):
|
||||
# https://github.com/kohya-ss/musubi-tuner
|
||||
converted_state_dict = {}
|
||||
original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
|
||||
def get_alpha_scales(down_weight, key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = original_state_dict.pop(key + ".alpha").item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for i in range(num_blocks):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,7 +52,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
|
||||
@@ -21,7 +21,6 @@ import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
@@ -261,11 +260,6 @@ class FromOriginalModelMixin:
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
if quantization_config is not None:
|
||||
user_agent["quant"] = quantization_config.quant_method.value
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
@@ -284,7 +278,6 @@ class FromOriginalModelMixin:
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
|
||||
@@ -177,7 +177,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
@@ -405,16 +404,13 @@ def load_single_file_checkpoint(
|
||||
local_files_only=None,
|
||||
revision=None,
|
||||
disable_mmap=False,
|
||||
user_agent=None,
|
||||
):
|
||||
if user_agent is None:
|
||||
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
||||
|
||||
if os.path.isfile(pretrained_model_link_or_path):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path
|
||||
|
||||
else:
|
||||
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
||||
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
||||
pretrained_model_link_or_path = _get_model_file(
|
||||
repo_id,
|
||||
weights_name=weights_name,
|
||||
@@ -642,9 +638,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
|
||||
model_type = "ltx-video-0.9.5"
|
||||
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
@@ -2409,41 +2403,13 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_095_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
}
|
||||
|
||||
if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
|
||||
@@ -49,7 +49,6 @@ if is_torch_available():
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
]
|
||||
_import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
|
||||
@@ -77,7 +76,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
@@ -135,7 +133,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
MultiControlNetModel,
|
||||
MultiControlNetUnionModel,
|
||||
SanaControlNetModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SparseControlNetModel,
|
||||
@@ -154,7 +151,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
|
||||
@@ -829,7 +829,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
||||
|
||||
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
|
||||
|
||||
@@ -1285,7 +1285,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
||||
|
||||
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
|
||||
|
||||
@@ -887,7 +887,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
@@ -909,7 +909,7 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
@@ -255,7 +255,7 @@ class Decoder(nn.Module):
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
prev_output_channel=None,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
|
||||
@@ -9,7 +9,6 @@ if is_torch_available():
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
)
|
||||
from .controlnet_sana import SanaControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
|
||||
from .controlnet_sparsectrl import (
|
||||
SparseControlNetConditioningEmbedding,
|
||||
|
||||
@@ -298,6 +298,15 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
@@ -311,15 +320,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
@@ -430,7 +430,7 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
|
||||
@@ -454,18 +454,17 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
from ..transformers.sana_transformer import SanaTransformerBlock
|
||||
from .controlnet import zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class SanaControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 32,
|
||||
out_channels: Optional[int] = 32,
|
||||
num_attention_heads: int = 70,
|
||||
attention_head_dim: int = 32,
|
||||
num_layers: int = 7,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
caption_channels: int = 2304,
|
||||
mlp_ratio: float = 2.5,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = False,
|
||||
sample_size: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
self.patch_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
pos_embed_type="sincos" if interpolation_scale is not None else None,
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
SanaTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
num_cross_attention_heads=num_cross_attention_heads,
|
||||
cross_attention_head_dim=cross_attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
|
||||
self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
controlnet_block = nn.Linear(inner_dim, inner_dim)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
p = self.config.patch_size
|
||||
post_patch_height, post_patch_width = height // p, width // p
|
||||
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
|
||||
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
|
||||
|
||||
# 2. Transformer blocks
|
||||
block_res_samples = ()
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
|
||||
# 3. ControlNet blocks
|
||||
controlnet_block_res_samples = ()
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_res_samples,)
|
||||
|
||||
return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
|
||||
@@ -38,7 +38,7 @@ if is_transformers_available():
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def text_encoder_attn_modules(text_encoder: nn.Module):
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder: nn.Module):
|
||||
return attn_modules
|
||||
|
||||
|
||||
def text_encoder_mlp_modules(text_encoder: nn.Module):
|
||||
def text_encoder_mlp_modules(text_encoder):
|
||||
mlp_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
|
||||
@@ -21,7 +21,6 @@ if is_torch_available():
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
|
||||
@@ -13,15 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
@@ -74,23 +74,15 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
|
||||
# because original input are in flattened format, we have to flatten this 2d grid as well.
|
||||
h_p, w_p = h // self.patch_size, w // self.patch_size
|
||||
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
|
||||
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
|
||||
|
||||
# Calculate the top-left corner indices for the centered patch grid
|
||||
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
|
||||
starth = h_max // 2 - h_p // 2
|
||||
endh = starth + h_p
|
||||
startw = w_max // 2 - w_p // 2
|
||||
|
||||
# Generate the row and column indices for the desired patch grid
|
||||
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
|
||||
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
|
||||
|
||||
# Create a 2D grid of indices
|
||||
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
|
||||
|
||||
# Convert the 2D grid indices to flattened 1D indices
|
||||
selected_indices = (row_indices * w_max + col_indices).flatten()
|
||||
|
||||
return selected_indices
|
||||
endw = startw + w_p
|
||||
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
|
||||
return original_pe_indexes.flatten()
|
||||
|
||||
def forward(self, latent):
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
@@ -168,20 +160,14 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
||||
self.ff = AuraFlowFeedForward(dim, dim * 4)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
|
||||
residual = hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
# Norm + Projection.
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
# Attention.
|
||||
attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
|
||||
attn_output = self.attn(hidden_states=norm_hidden_states)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
|
||||
@@ -237,15 +223,10 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
||||
):
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
# Norm + Projection.
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
@@ -255,9 +236,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
**attention_kwargs,
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
@@ -275,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
|
||||
|
||||
@@ -283,17 +262,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
sample_size (`int`): The width of the latent images. This is fixed during training since
|
||||
it is used to learn a number of position embeddings.
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
|
||||
num_single_dit_layers (`int`, *optional*, defaults to 32):
|
||||
num_single_dit_layers (`int`, *optional*, defaults to 4):
|
||||
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
|
||||
representations.
|
||||
attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
|
||||
out_channels (`int`, defaults to 4): Number of output channels.
|
||||
pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
|
||||
out_channels (`int`, defaults to 16): Number of output channels.
|
||||
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
|
||||
"""
|
||||
|
||||
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
|
||||
@@ -470,24 +449,8 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
# Apply patch embedding, timestep embedding, and project the caption embeddings.
|
||||
@@ -511,10 +474,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
|
||||
@@ -531,9 +491,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
)
|
||||
|
||||
else:
|
||||
combined_hidden_states = block(
|
||||
hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
|
||||
)
|
||||
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
|
||||
|
||||
hidden_states = combined_hidden_states[:, encoder_seq_len:]
|
||||
|
||||
@@ -554,10 +512,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -483,7 +483,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
@@ -547,7 +546,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
# 2. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
@@ -558,11 +557,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
else:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@@ -572,8 +569,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
# 3. Normalization
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
|
||||
|
||||
@@ -1,922 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HiDreamImageFeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class HiDreamImagePooledEmbed(nn.Module):
|
||||
def __init__(self, text_emb_dim, hidden_size):
|
||||
super().__init__()
|
||||
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
|
||||
|
||||
def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
|
||||
return self.pooled_embedder(pooled_embed)
|
||||
|
||||
|
||||
class HiDreamImageTimestepEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
|
||||
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
|
||||
t_emb = self.timestep_embedder(t_emb)
|
||||
return t_emb
|
||||
|
||||
|
||||
class HiDreamImageOutEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
|
||||
hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HiDreamImagePatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=1024,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent)
|
||||
return latent
|
||||
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
is_mps = pos.device.type == "mps"
|
||||
is_npu = pos.device.type == "npu"
|
||||
|
||||
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.float()
|
||||
|
||||
|
||||
class HiDreamImageEmbedND(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(2)
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HiDreamAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
scale_qk: bool = True,
|
||||
eps: float = 1e-5,
|
||||
processor=None,
|
||||
out_dim: int = None,
|
||||
single: bool = False,
|
||||
):
|
||||
super(Attention, self).__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.sliceable_head_dim = heads
|
||||
self.single = single
|
||||
|
||||
self.to_q = nn.Linear(query_dim, self.inner_dim)
|
||||
self.to_k = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
self.to_v = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
self.to_out = nn.Linear(self.inner_dim, self.out_dim)
|
||||
self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
|
||||
self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
|
||||
|
||||
if not single:
|
||||
self.to_q_t = nn.Linear(query_dim, self.inner_dim)
|
||||
self.to_k_t = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
self.to_v_t = nn.Linear(self.inner_dim, self.inner_dim)
|
||||
self.to_out_t = nn.Linear(self.inner_dim, self.out_dim)
|
||||
self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
|
||||
self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
norm_hidden_states: torch.Tensor,
|
||||
hidden_states_masks: torch.Tensor = None,
|
||||
norm_encoder_hidden_states: torch.Tensor = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states=norm_hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
|
||||
class HiDreamAttnProcessor:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: HiDreamAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states_masks: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
dtype = hidden_states.dtype
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
|
||||
key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
|
||||
value_i = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key_i.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
|
||||
if hidden_states_masks is not None:
|
||||
key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
|
||||
|
||||
if not attn.single:
|
||||
query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
|
||||
key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
|
||||
value_t = attn.to_v_t(encoder_hidden_states)
|
||||
|
||||
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
num_image_tokens = query_i.shape[1]
|
||||
num_text_tokens = query_t.shape[1]
|
||||
query = torch.cat([query_i, query_t], dim=1)
|
||||
key = torch.cat([key_i, key_t], dim=1)
|
||||
value = torch.cat([value_i, value_t], dim=1)
|
||||
else:
|
||||
query = query_i
|
||||
key = key_i
|
||||
value = value_i
|
||||
|
||||
if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
|
||||
query, key = apply_rope(query, key, image_rotary_emb)
|
||||
|
||||
else:
|
||||
query_1, query_2 = query.chunk(2, dim=-1)
|
||||
key_1, key_2 = key.chunk(2, dim=-1)
|
||||
query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
hidden_states_i = attn.to_out(hidden_states_i)
|
||||
hidden_states_t = attn.to_out_t(hidden_states_t)
|
||||
return hidden_states_i, hidden_states_t
|
||||
else:
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
|
||||
super().__init__()
|
||||
self.top_k = num_activated_experts
|
||||
self.n_routed_experts = num_routed_experts
|
||||
|
||||
self.scoring_func = "softmax"
|
||||
self.alpha = aux_loss_alpha
|
||||
self.seq_aux = False
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = False
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
# print(bsz, seq_len, h)
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
if self.scoring_func == "softmax":
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
|
||||
|
||||
### select top-k experts
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
### expert-level computation auxiliary loss
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
aux_topk = self.top_k
|
||||
# always compute aux loss based on the naive greedy topk method
|
||||
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||
if self.seq_aux:
|
||||
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||
ce.scatter_add_(
|
||||
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
|
||||
).div_(seq_len * aux_topk / self.n_routed_experts)
|
||||
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
|
||||
else:
|
||||
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
||||
ce = mask_ce.float().mean(0)
|
||||
|
||||
Pi = scores_for_aux.mean(0)
|
||||
fi = ce * self.n_routed_experts
|
||||
aux_loss = (Pi * fi).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MOEFeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
num_routed_experts: int,
|
||||
num_activated_experts: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
|
||||
self.experts = nn.ModuleList(
|
||||
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
|
||||
)
|
||||
self.gate = MoEGate(
|
||||
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
|
||||
)
|
||||
self.num_activated_experts = num_activated_experts
|
||||
|
||||
def forward(self, x):
|
||||
wtype = x.dtype
|
||||
identity = x
|
||||
orig_shape = x.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||
y = torch.empty_like(x, dtype=wtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape).to(dtype=wtype)
|
||||
# y = AddAuxiliaryLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.num_activated_experts
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# for fp16 and other dtype
|
||||
expert_cache = expert_cache.to(expert_out.dtype)
|
||||
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
|
||||
return expert_cache
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=HiDreamAttnProcessor(),
|
||||
single=True,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
)
|
||||
else:
|
||||
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states_masks: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
wtype = hidden_states.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
|
||||
:, None
|
||||
].chunk(6, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
|
||||
attn_output_i = self.attn1(
|
||||
norm_hidden_states,
|
||||
hidden_states_masks,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = gate_msa_i * attn_output_i + hidden_states
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
|
||||
hidden_states = ff_output_i + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HiDreamImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
|
||||
|
||||
# 1. Attention
|
||||
self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
|
||||
self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
|
||||
self.attn1 = HiDreamAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=HiDreamAttnProcessor(),
|
||||
single=False,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
|
||||
if num_routed_experts > 0:
|
||||
self.ff_i = MOEFeedForwardSwiGLU(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
)
|
||||
else:
|
||||
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
|
||||
self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states_masks: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
wtype = hidden_states.dtype
|
||||
(
|
||||
shift_msa_i,
|
||||
scale_msa_i,
|
||||
gate_msa_i,
|
||||
shift_mlp_i,
|
||||
scale_mlp_i,
|
||||
gate_mlp_i,
|
||||
shift_msa_t,
|
||||
scale_msa_t,
|
||||
gate_msa_t,
|
||||
shift_mlp_t,
|
||||
scale_mlp_t,
|
||||
gate_mlp_t,
|
||||
) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
|
||||
|
||||
# 1. MM-Attention
|
||||
norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
|
||||
norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
|
||||
|
||||
attn_output_i, attn_output_t = self.attn1(
|
||||
norm_hidden_states,
|
||||
hidden_states_masks,
|
||||
norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = gate_msa_i * attn_output_i + hidden_states
|
||||
encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
|
||||
norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
|
||||
|
||||
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
|
||||
ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
|
||||
hidden_states = ff_output_i + hidden_states
|
||||
encoder_hidden_states = ff_output_t + encoder_hidden_states
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HiDreamBlock(nn.Module):
|
||||
def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
|
||||
super().__init__()
|
||||
self.block = block
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states_masks: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
return self.block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Optional[int] = None,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 16,
|
||||
num_single_layers: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 20,
|
||||
caption_channels: List[int] = None,
|
||||
text_emb_dim: int = 2048,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
axes_dims_rope: Tuple[int, int] = (32, 32),
|
||||
max_resolution: Tuple[int, int] = (128, 128),
|
||||
llama_layers: List[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
|
||||
self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
|
||||
self.x_embedder = HiDreamImagePatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
out_channels=self.inner_dim,
|
||||
)
|
||||
self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.double_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamBlock(
|
||||
HiDreamImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
)
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_stream_blocks = nn.ModuleList(
|
||||
[
|
||||
HiDreamBlock(
|
||||
HiDreamImageSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
)
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
|
||||
|
||||
caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
|
||||
caption_projection = []
|
||||
for caption_channel in caption_channels:
|
||||
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
|
||||
if is_training:
|
||||
B, S, F = x.shape
|
||||
C = F // (self.config.patch_size * self.config.patch_size)
|
||||
x = (
|
||||
x.reshape(B, S, self.config.patch_size, self.config.patch_size, C)
|
||||
.permute(0, 4, 1, 2, 3)
|
||||
.reshape(B, C, S, self.config.patch_size * self.config.patch_size)
|
||||
)
|
||||
else:
|
||||
x_arr = []
|
||||
p1 = self.config.patch_size
|
||||
p2 = self.config.patch_size
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
pH, pW = img_size
|
||||
t = x[i, : pH * pW].reshape(1, pH, pW, -1)
|
||||
F_token = t.shape[-1]
|
||||
C = F_token // (p1 * p2)
|
||||
t = t.reshape(1, pH, pW, p1, p2, C)
|
||||
t = t.permute(0, 5, 1, 3, 2, 4)
|
||||
t = t.reshape(1, C, pH * p1, pW * p2)
|
||||
x_arr.append(t)
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
return x
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
patch_size = self.config.patch_size
|
||||
patch_height, patch_width = height // patch_size, width // patch_size
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# create img_sizes
|
||||
img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
|
||||
img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# create hidden_states_masks
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
|
||||
hidden_states_masks[:, : patch_height * patch_width] = 1.0
|
||||
else:
|
||||
hidden_states_masks = None
|
||||
|
||||
# create img_ids
|
||||
img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
|
||||
row_indices = torch.arange(patch_height, device=device)[:, None]
|
||||
col_indices = torch.arange(patch_width, device=device)[None, :]
|
||||
img_ids[..., 1] = img_ids[..., 1] + row_indices
|
||||
img_ids[..., 2] = img_ids[..., 2] + col_indices
|
||||
img_ids = img_ids.reshape(patch_height * patch_width, -1)
|
||||
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
# Handle non-square latents
|
||||
img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
|
||||
img_ids_pad[: patch_height * patch_width, :] = img_ids
|
||||
img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
else:
|
||||
img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
# patchify hidden_states
|
||||
if hidden_states.shape[-2] != hidden_states.shape[-1]:
|
||||
# Handle non-square latents
|
||||
out = torch.zeros(
|
||||
(batch_size, channels, self.max_seq, patch_size * patch_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channels, patch_height, patch_size, patch_width, patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channels, patch_height * patch_width, patch_size * patch_size
|
||||
)
|
||||
out[:, :, 0 : patch_height * patch_width] = hidden_states
|
||||
hidden_states = out
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
||||
batch_size, self.max_seq, patch_size * patch_size * channels
|
||||
)
|
||||
|
||||
else:
|
||||
# Handle square latents
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channels, patch_height, patch_size, patch_width, patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, patch_height * patch_width, patch_size * patch_size * channels
|
||||
)
|
||||
|
||||
return hidden_states, hidden_states_masks, img_sizes, img_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timesteps: torch.LongTensor = None,
|
||||
encoder_hidden_states_t5: torch.Tensor = None,
|
||||
encoder_hidden_states_llama3: torch.Tensor = None,
|
||||
pooled_embeds: torch.Tensor = None,
|
||||
img_ids: Optional[torch.Tensor] = None,
|
||||
img_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
hidden_states_masks: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
|
||||
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
|
||||
encoder_hidden_states_t5 = encoder_hidden_states[0]
|
||||
encoder_hidden_states_llama3 = encoder_hidden_states[1]
|
||||
|
||||
if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
|
||||
deprecation_message = (
|
||||
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
|
||||
)
|
||||
deprecate("img_ids", "0.34.0", deprecation_message)
|
||||
|
||||
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
|
||||
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
|
||||
elif hidden_states_masks is not None and hidden_states.ndim != 3:
|
||||
raise ValueError(
|
||||
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
|
||||
)
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# spatial forward
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
|
||||
# Patchify the input
|
||||
if hidden_states_masks is None:
|
||||
hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
|
||||
|
||||
# Embed the hidden states
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# 0. time
|
||||
timesteps = self.t_embedder(timesteps, hidden_states_type)
|
||||
p_embedder = self.p_embedder(pooled_embeds)
|
||||
temb = timesteps + p_embedder
|
||||
|
||||
encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
new_encoder_hidden_states = []
|
||||
for i, enc_hidden_state in enumerate(encoder_hidden_states):
|
||||
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
|
||||
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
|
||||
new_encoder_hidden_states.append(enc_hidden_state)
|
||||
encoder_hidden_states = new_encoder_hidden_states
|
||||
encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
|
||||
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
|
||||
encoder_hidden_states.append(encoder_hidden_states_t5)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
batch_size,
|
||||
encoder_hidden_states[-1].shape[1]
|
||||
+ encoder_hidden_states[-2].shape[1]
|
||||
+ encoder_hidden_states[0].shape[1],
|
||||
3,
|
||||
device=img_ids.device,
|
||||
dtype=img_ids.dtype,
|
||||
)
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids)
|
||||
|
||||
# 2. Blocks
|
||||
block_id = 0
|
||||
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
|
||||
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
|
||||
for bid, block in enumerate(self.double_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
cur_encoder_hidden_states = torch.cat(
|
||||
[initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
hidden_states_masks,
|
||||
cur_encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
else:
|
||||
hidden_states, initial_encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
encoder_hidden_states=cur_encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
image_tokens_seq_len = hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
|
||||
hidden_states_seq_len = hidden_states.shape[1]
|
||||
if hidden_states_masks is not None:
|
||||
encoder_attention_mask_ones = torch.ones(
|
||||
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
|
||||
device=hidden_states_masks.device,
|
||||
dtype=hidden_states_masks.dtype,
|
||||
)
|
||||
hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
|
||||
|
||||
for bid, block in enumerate(self.single_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
hidden_states_masks,
|
||||
None,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_masks=hidden_states_masks,
|
||||
encoder_hidden_states=None,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
|
||||
output = self.final_layer(hidden_states, temb)
|
||||
output = self.unpatchify(output, img_sizes, self.training)
|
||||
if hidden_states_masks is not None:
|
||||
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output, hidden_states_masks)
|
||||
return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
|
||||
@@ -10,7 +10,6 @@ from ..utils import (
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
@@ -156,6 +155,7 @@ else:
|
||||
]
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
@@ -221,7 +221,6 @@ else:
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
]
|
||||
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
|
||||
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
|
||||
_import_structure["hunyuan_video"] = [
|
||||
"HunyuanVideoPipeline",
|
||||
@@ -281,7 +280,7 @@ else:
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["pia"] = ["PIAPipeline"]
|
||||
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
|
||||
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline", "SanaControlNetPipeline"]
|
||||
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_audio"] = [
|
||||
@@ -415,18 +414,6 @@ else:
|
||||
"KolorsImg2ImgPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import (
|
||||
dummy_torch_and_transformers_and_opencv_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
|
||||
else:
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -525,6 +512,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
|
||||
from .consisid import ConsisIDPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
@@ -586,7 +574,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_video import (
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
@@ -664,7 +651,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintPipeline
|
||||
from .sana import SanaPipeline, SanaSprintPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
|
||||
@@ -774,14 +761,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KolorsPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_opencv_objects import *
|
||||
else:
|
||||
from .consisid import ConsisIDPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -12,25 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5Tokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import AuraFlowLoraLoaderMixin
|
||||
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
@@ -120,7 +112,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
class AuraFlowPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
@@ -241,7 +233,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
@@ -268,20 +259,10 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
@@ -365,11 +346,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_attention_mask = None
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
|
||||
@@ -427,10 +403,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
@@ -456,7 +428,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
max_sequence_length: int = 256,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
@@ -515,10 +486,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
@@ -553,7 +520,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
|
||||
# 2. Determine batch size.
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -564,7 +530,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -588,7 +553,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
@@ -630,7 +594,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
return_dict=False,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -5,7 +5,6 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_opencv_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -16,12 +15,12 @@ _import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_opencv_available()):
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_consisid"] = ["ConsisIDPipeline"]
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
@@ -28,16 +29,12 @@ from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import CogVideoXDPMScheduler
|
||||
from ...utils import is_opencv_available, logging, replace_example_docstring
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import ConsisIDPipelineOutput
|
||||
|
||||
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["HiDreamImagePipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_hidream_image"] = ["HiDreamImagePipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_hidream_image import HiDreamImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,21 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class HiDreamImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for HiDreamImage pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -100,50 +100,6 @@ DEFAULT_PROMPT_TEMPLATE = {
|
||||
}
|
||||
|
||||
|
||||
def _expand_input_ids_with_image_tokens(
|
||||
text_input_ids,
|
||||
prompt_attention_mask,
|
||||
max_sequence_length,
|
||||
image_token_index,
|
||||
image_emb_len,
|
||||
image_emb_start,
|
||||
image_emb_end,
|
||||
pad_token_id,
|
||||
):
|
||||
special_image_token_mask = text_input_ids == image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
|
||||
|
||||
max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
expanded_input_ids = torch.full(
|
||||
(text_input_ids.shape[0], max_expanded_length),
|
||||
pad_token_id,
|
||||
dtype=text_input_ids.dtype,
|
||||
device=text_input_ids.device,
|
||||
)
|
||||
expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
|
||||
expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
|
||||
|
||||
expanded_attention_mask = torch.zeros(
|
||||
(text_input_ids.shape[0], max_expanded_length),
|
||||
dtype=prompt_attention_mask.dtype,
|
||||
device=prompt_attention_mask.device,
|
||||
)
|
||||
attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
|
||||
expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
|
||||
expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
|
||||
position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
|
||||
|
||||
return {
|
||||
"input_ids": expanded_input_ids,
|
||||
"attention_mask": expanded_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -295,12 +251,6 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt = [prompt_template["template"].format(p) for p in prompt]
|
||||
|
||||
crop_start = prompt_template.get("crop_start", None)
|
||||
|
||||
image_emb_len = prompt_template.get("image_emb_len", 576)
|
||||
image_emb_start = prompt_template.get("image_emb_start", 5)
|
||||
image_emb_end = prompt_template.get("image_emb_end", 581)
|
||||
double_return_token_id = prompt_template.get("double_return_token_id", 271)
|
||||
|
||||
if crop_start is None:
|
||||
prompt_template_input = self.tokenizer(
|
||||
prompt_template["template"],
|
||||
@@ -330,25 +280,19 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
|
||||
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
|
||||
|
||||
image_token_index = self.text_encoder.config.image_token_index
|
||||
pad_token_id = self.text_encoder.config.pad_token_id
|
||||
expanded_inputs = _expand_input_ids_with_image_tokens(
|
||||
text_input_ids,
|
||||
prompt_attention_mask,
|
||||
max_sequence_length,
|
||||
image_token_index,
|
||||
image_emb_len,
|
||||
image_emb_start,
|
||||
image_emb_end,
|
||||
pad_token_id,
|
||||
)
|
||||
prompt_embeds = self.text_encoder(
|
||||
**expanded_inputs,
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
pixel_values=image_embeds,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
image_emb_len = prompt_template.get("image_emb_len", 576)
|
||||
image_emb_start = prompt_template.get("image_emb_start", 5)
|
||||
image_emb_end = prompt_template.get("image_emb_end", 581)
|
||||
double_return_token_id = prompt_template.get("double_return_token_id", 271)
|
||||
|
||||
if crop_start is not None and crop_start > 0:
|
||||
text_crop_start = crop_start - 1 + image_emb_len
|
||||
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user