Compare commits

..

1 Commits

Author SHA1 Message Date
DN6 39115294b3 update 2025-10-21 11:51:18 +05:30
139 changed files with 740 additions and 12286 deletions
+1 -1
View File
@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
+8 -29
View File
@@ -42,39 +42,18 @@ jobs:
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
run: |
echo "$CHANGED_FILES"
ALLOWED_IMAGES=(
diffusers-pytorch-cpu
diffusers-pytorch-cuda
diffusers-pytorch-xformers-cuda
diffusers-pytorch-minimum-cuda
diffusers-doc-builder
)
declare -A IMAGES_TO_BUILD=()
for FILE in $CHANGED_FILES; do
for FILE in $CHANGED_FILES; do
# skip anything that isn't still on disk
if [[ ! -e "$FILE" ]]; then
if [[ ! -f "$FILE" ]]; then
echo "Skipping removed file $FILE"
continue
fi
if [[ "$FILE" == docker/*Dockerfile ]]; then
DOCKER_PATH="${FILE%/Dockerfile}"
DOCKER_TAG=$(basename "$DOCKER_PATH")
echo "Building Docker image for $DOCKER_TAG"
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
fi
for IMAGE in "${ALLOWED_IMAGES[@]}"; do
if [[ "$FILE" == docker/${IMAGE}/* ]]; then
IMAGES_TO_BUILD["$IMAGE"]=1
fi
done
done
if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
echo "No relevant Docker changes detected."
exit 0
fi
for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
DOCKER_PATH="docker/${IMAGE}"
echo "Building Docker image for $IMAGE"
docker build -t "$IMAGE" "$DOCKER_PATH"
done
if: steps.file_changes.outputs.all != ''
+1 -1
View File
@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600
+1 -1
View File
@@ -26,7 +26,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
+1 -1
View File
@@ -22,7 +22,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
+1 -1
View File
@@ -24,7 +24,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
+1 -1
View File
@@ -14,7 +14,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 50000
+1 -1
View File
@@ -18,7 +18,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
+1 -1
View File
@@ -8,7 +8,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
+1 -1
View File
@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_xet \
hf_transfer \
setuptools==69.5.1 \
bitsandbytes \
torchao \
+1 -1
View File
@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
scipy \
tensorboard \
transformers \
hf_xet
hf_transfer
CMD ["/bin/bash"]
+3 -2
View File
@@ -38,12 +38,13 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
datasets \
hf-doc-builder \
huggingface-hub \
hf_xet \
hf_transfer \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
transformers \
hf_transfer
CMD ["/bin/bash"]
+1 -1
View File
@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_xet
hf_transfer
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
+1 -1
View File
@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_xet
hf_transfer
CMD ["/bin/bash"]
@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_xet
hf_transfer
CMD ["/bin/bash"]
@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_xet \
hf_transfer \
xformers
CMD ["/bin/bash"]
-16
View File
@@ -323,8 +323,6 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -349,8 +347,6 @@
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuanimage_transformer_2d
title: HunyuanImageTransformer2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
@@ -415,10 +411,6 @@
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -471,8 +463,6 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -529,8 +519,6 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5
title: Kandinsky 5
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
@@ -557,8 +545,6 @@
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/prx
title: PRX
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
@@ -632,8 +618,6 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
@@ -1,32 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImage
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImage
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImage
[[autodoc]] AutoencoderKLHunyuanImage
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -1,32 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImageRefiner
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImageRefiner
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImageRefiner
[[autodoc]] AutoencoderKLHunyuanImageRefiner
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# ChromaTransformer2DModel
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
## ChromaTransformer2DModel
@@ -1,30 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HunyuanImageTransformer2DModel
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanImageTransformer2DModel
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanImageTransformer2DModel
[[autodoc]] HunyuanImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -1,19 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaFiboTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
## BriaFiboTransformer2DModel
[[autodoc]] BriaFiboTransformer2DModel
-45
View File
@@ -1,45 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination - but not control. FIBO changes that.
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaPipeline
[[autodoc]] BriaPipeline
- all
- __call__
+6 -7
View File
@@ -19,21 +19,20 @@ specific language governing permissions and limitations under the License.
Chroma is a text to image generation model based on Flux.
Original model checkpoints for Chroma can be found here:
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
> [!TIP]
> Chroma can use all the same optimizations as Flux.
## Inference
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = [
@@ -64,10 +63,10 @@ Then run the following example
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
model_id = "lodestones/Chroma1-HD"
model_id = "lodestones/Chroma"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
@@ -1,152 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HunyuanImage2.1
HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
HunyuanImage-2.1 comes in the following variants:
| model type | model id |
|:----------:|:--------:|
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
```python
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
```
You can inspect the `guider` object:
```py
>>> pipe.guider
AdaptiveProjectedMixGuidance {
"_class_name": "AdaptiveProjectedMixGuidance",
"_diffusers_version": "0.36.0.dev0",
"adaptive_projected_guidance_momentum": -0.5,
"adaptive_projected_guidance_rescale": 10.0,
"adaptive_projected_guidance_scale": 10.0,
"adaptive_projected_guidance_start_step": 5,
"enabled": true,
"eta": 0.0,
"guidance_rescale": 0.0,
"guidance_scale": 3.5,
"start": 0.0,
"stop": 1.0,
"use_original_formulation": false
}
State:
step: None
num_inference_steps: None
timestep: None
count_prepared: 0
enabled: True
num_conditions: 2
momentum_buffer: None
is_apg_enabled: False
is_cfg_enabled: True
```
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
# Update the guider configuration
pipe.guider = pipe.guider.new(guidance_scale=5.0)
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
image = pipe(
prompt=prompt,
num_inference_steps=50,
height=2048,
width=2048,
).images[0]
image.save("image.png")
```
## HunyuanImage-2.1-Distilled
use `distilled_guidance_scale` with the guidance-distilled checkpoint,
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
out = pipe(
prompt,
num_inference_steps=8,
distilled_guidance_scale=3.25,
height=2048,
width=2048,
generator=generator,
).images[0]
```
## HunyuanImagePipeline
[[autodoc]] HunyuanImagePipeline
- all
- __call__
## HunyuanImageRefinerPipeline
[[autodoc]] HunyuanImageRefinerPipeline
- all
- __call__
## HunyuanImagePipelineOutput
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
-149
View File
@@ -1,149 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
> [!TIP]
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
## Available Models
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
| model_id | Description | Use Cases |
|------------|-------------|-----------|
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
All models are available in 5-second and 10-second video generation versions.
## Kandinsky5T2VPipeline
[[autodoc]] Kandinsky5T2VPipeline
- all
- __call__
## Usage Examples
### Basic Text-to-Video Generation
```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video
# Load the pipeline
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### 10 second Models
**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
```python
pipe = Kandinsky5T2VPipeline.from_pretrained(
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Set attention backend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
) # <--- Compile with max-autotune-no-cudagraphs
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=241,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
output = pipe(
prompt="A beautiful sunset over mountains",
num_inference_steps=16, # <--- Model is distilled in 16 steps
guidance_scale=1.0, # <--- no CFG
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
## Citation
```bibtex
@misc{kandinsky2025,
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
year = 2025
}
```
-131
View File
@@ -1,131 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# PRX
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
## Available models
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
from diffusers.pipelines.prx import PRXPipeline
# Load pipeline - VAE and text encoder will be loaded from HuggingFace
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A front-facing portrait of a lion the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
### Manual Component Loading
Load components individually to customize the pipeline for instance to use quantized models.
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
"checkpoints/prx-512-t2i-sft",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
subfolder="vae",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
pipe = PRXPipeline(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae
)
pipe.to("cuda")
```
## Memory Optimization
For memory-constrained environments:
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PRXPipeline
[[autodoc]] PRXPipeline
- all
- __call__
## PRXPipelineOutput
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
@@ -21,7 +21,6 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
@@ -140,7 +139,6 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
File diff suppressed because it is too large Load Diff
-345
View File
@@ -1,345 +0,0 @@
#!/usr/bin/env python3
"""
Script to convert PRX checkpoint from original codebase to diffusers format.
"""
import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline
DEFAULT_RESOLUTION = 512
@dataclass(frozen=True)
class PRXBase:
context_in_dim: int = 2304
hidden_size: int = 1792
mlp_ratio: float = 3.5
num_heads: int = 28
depth: int = 16
axes_dim: Tuple[int, int] = (32, 32)
theta: int = 10_000
time_factor: float = 1000.0
time_max_period: int = 10_000
@dataclass(frozen=True)
class PRXFlux(PRXBase):
in_channels: int = 16
patch_size: int = 2
@dataclass(frozen=True)
class PRXDCAE(PRXBase):
in_channels: int = 32
patch_size: int = 1
def build_config(vae_type: str) -> Tuple[dict, int]:
if vae_type == "flux":
cfg = PRXFlux()
elif vae_type == "dc-ae":
cfg = PRXDCAE()
else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
config_dict = asdict(cfg)
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
return config_dict
def create_parameter_mapping(depth: int) -> dict:
"""Create mapping from old parameter names to new diffusers names."""
# Key mappings for structural changes
mapping = {}
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
for i in range(depth):
# QKV projections moved to attention module
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
# K norm for text tokens moved to attention module
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
# Attention output projection
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
return mapping
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
"""Convert old checkpoint parameters to new diffusers format."""
print("Converting checkpoint parameters...")
mapping = create_parameter_mapping(depth)
converted_state_dict = {}
for key, value in old_state_dict.items():
new_key = key
# Apply specific mappings if needed
if key in mapping:
new_key = mapping[key]
print(f" Mapped: {key} -> {new_key}")
converted_state_dict[new_key] = value
print(f"✓ Converted {len(converted_state_dict)} parameters")
return converted_state_dict
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
"""Create and load PRXTransformer2DModel from old checkpoint."""
print(f"Loading checkpoint from: {checkpoint_path}")
# Load old checkpoint
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Handle different checkpoint formats
if isinstance(old_checkpoint, dict):
if "model" in old_checkpoint:
state_dict = old_checkpoint["model"]
elif "state_dict" in old_checkpoint:
state_dict = old_checkpoint["state_dict"]
else:
state_dict = old_checkpoint
else:
state_dict = old_checkpoint
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
# Convert parameter names if needed
model_depth = int(config.get("depth", 16))
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
# Create transformer with config
print("Creating PRXTransformer2DModel...")
transformer = PRXTransformer2DModel(**config)
# Load state dict
print("Loading converted parameters...")
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
if missing_keys:
print(f"⚠ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠ Unexpected keys: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
print("✓ All parameters loaded successfully!")
return transformer
def create_scheduler_config(output_path: str, shift: float):
"""Create FlowMatchEulerDiscreteScheduler config."""
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
scheduler_path = os.path.join(output_path, "scheduler")
os.makedirs(scheduler_path, exist_ok=True)
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
json.dump(scheduler_config, f, indent=2)
print("✓ Created scheduler config")
def download_and_save_vae(vae_type: str, output_path: str):
"""Download and save VAE to local directory."""
from diffusers import AutoencoderDC, AutoencoderKL
vae_path = os.path.join(output_path, "vae")
os.makedirs(vae_path, exist_ok=True)
if vae_type == "flux":
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
else: # dc-ae
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
vae.save_pretrained(vae_path)
print(f"✓ Saved VAE to {vae_path}")
def download_and_save_text_encoder(output_path: str):
"""Download and save T5Gemma text encoder and tokenizer."""
from transformers import GemmaTokenizerFast
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
text_encoder_path = os.path.join(output_path, "text_encoder")
tokenizer_path = os.path.join(output_path, "tokenizer")
os.makedirs(text_encoder_path, exist_ok=True)
os.makedirs(tokenizer_path, exist_ok=True)
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
# Extract and save only the encoder
t5gemma_encoder = t5gemma_model.encoder
t5gemma_encoder.save_pretrained(text_encoder_path)
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
tokenizer.save_pretrained(tokenizer_path)
print(f"✓ Saved tokenizer to {tokenizer_path}")
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
"""Create model_index.json for the pipeline."""
if vae_type == "flux":
vae_class = "AutoencoderKL"
else: # dc-ae
vae_class = "AutoencoderDC"
model_index = {
"_class_name": "PRXPipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"default_sample_size": default_image_size,
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["prx", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "PRXTransformer2DModel"],
"vae": ["diffusers", vae_class],
}
model_index_path = os.path.join(output_path, "model_index.json")
with open(model_index_path, "w") as f:
json.dump(model_index, f, indent=2)
def main(args):
# Validate inputs
if not os.path.exists(args.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
config = build_config(args.vae_type)
# Create output directory
os.makedirs(args.output_path, exist_ok=True)
print(f"✓ Output directory: {args.output_path}")
# Create transformer from checkpoint
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
# Save transformer
transformer_path = os.path.join(args.output_path, "transformer")
os.makedirs(transformer_path, exist_ok=True)
# Save config
with open(os.path.join(transformer_path, "config.json"), "w") as f:
json.dump(config, f, indent=2)
# Save model weights as safetensors
state_dict = transformer.state_dict()
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
print(f"✓ Saved transformer to {transformer_path}")
# Create scheduler config
create_scheduler_config(args.output_path, args.shift)
download_and_save_vae(args.vae_type, args.output_path)
download_and_save_text_encoder(args.output_path)
# Create model_index.json
create_model_index(args.vae_type, args.resolution, args.output_path)
# Verify the pipeline can be loaded
try:
pipeline = PRXPipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
# Display model info
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
print(f"✓ Transformer parameters: {num_params:,}")
except Exception as e:
print(f"Pipeline verification failed: {e}")
return False
print("Conversion completed successfully!")
print(f"Converted pipeline saved to: {args.output_path}")
print(f"VAE type: {args.vae_type}")
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
)
parser.add_argument(
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
)
parser.add_argument(
"--vae_type",
type=str,
choices=["flux", "dc-ae"],
required=True,
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
)
parser.add_argument(
"--resolution",
type=int,
choices=[256, 512, 1024],
default=DEFAULT_RESOLUTION,
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
)
parser.add_argument(
"--shift",
type=float,
default=3.0,
help="Shift for the scheduler",
)
args = parser.parse_args()
try:
success = main(args)
if not success:
sys.exit(1)
except Exception as e:
print(f"Conversion failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
-22
View File
@@ -149,9 +149,7 @@ else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AdaptiveProjectedMixGuidance",
"AutoGuidance",
"BaseGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
@@ -186,8 +184,6 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -198,7 +194,6 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
@@ -221,7 +216,6 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
@@ -240,7 +234,6 @@ else:
"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"PRXTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
@@ -431,7 +424,6 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
@@ -469,8 +461,6 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanImagePipeline",
"HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
@@ -529,7 +519,6 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"PRXPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
@@ -858,9 +847,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .guiders import (
AdaptiveProjectedGuidance,
AdaptiveProjectedMixGuidance,
AutoGuidance,
BaseGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
@@ -891,8 +878,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -903,7 +888,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
@@ -926,7 +910,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
@@ -945,7 +928,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
@@ -1106,7 +1088,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
@@ -1144,8 +1125,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanImagePipeline,
HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
@@ -1204,7 +1183,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PRXPipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
+13 -3
View File
@@ -14,18 +14,28 @@
from typing import Union
from ..utils import is_torch_available, logging
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -65,9 +65,8 @@ class AdaptiveProjectedGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
@@ -77,14 +76,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -148,44 +152,6 @@ class MomentumBuffer:
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def normalized_guidance(
pred_cond: torch.Tensor,
@@ -1,284 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedMixGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
used, and momentum buffer is updated).
enabled (`bool`, defaults to `True`):
Whether this guidance is enabled.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 3.5,
guidance_rescale: float = 0.0,
adaptive_projected_guidance_scale: float = 10.0,
adaptive_projected_guidance_momentum: float = -0.5,
adaptive_projected_guidance_rescale: float = 10.0,
eta: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
adaptive_projected_guidance_start_step: int = 5,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# no guidance
if not self._is_cfg_enabled():
pred = pred_cond
# CFG + update momentum buffer
elif not self._is_apg_enabled():
if self.momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
# CFG + update momentum buffer
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
# APG
elif self._is_apg_enabled():
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.adaptive_projected_guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled() or self._is_cfg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
if not self._is_cfg_enabled():
return False
is_within_range = False
if self._step is not None:
is_within_range = self._step > self.adaptive_projected_guidance_start_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
else:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
return is_within_range and not is_close
def get_state(self):
state = super().get_state()
state["momentum_buffer"] = self.momentum_buffer
state["is_apg_enabled"] = self._is_apg_enabled()
state["is_cfg_enabled"] = self._is_cfg_enabled()
return state
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def update_momentum_buffer(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
momentum_buffer: Optional[MomentumBuffer] = None,
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
if momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
diff = momentum_buffer.running_average
else:
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
+9 -5
View File
@@ -72,9 +72,8 @@ class AutoGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
@@ -133,11 +132,16 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -27,50 +27,43 @@ if TYPE_CHECKING:
class ClassifierFreeGuidance(BaseGuidance):
"""
Implements Classifier-Free Guidance (CFG) for diffusion models.
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
Reference: https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
unconditional data, then combining predictions during inference. This allows trading off between quality (high
guidance) and diversity (low guidance).
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
**Two CFG Formulations:**
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
1. **Original formulation** (from paper):
```
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
```
Moves conditional predictions further from unconditional ones.
2. **Diffusers-native formulation** (default, from Imagen paper):
```
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
```
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
quality", "watermarks"). Equivalent in theory but more intuitive.
Use `use_original_formulation=True` to switch to the original formulation.
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
may reduce quality. Typical range: 1.0-20.0.
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
to 1.0 (full rescaling).
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
diffusers-native formulation from the Imagen paper.
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
steps.
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
steps.
enabled (`bool`, defaults to `True`):
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@@ -83,19 +76,23 @@ class ClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# YiYi Notes: add default behavior for self._enabled == False
if not self._enabled:
pred = pred_cond
elif self._step < self.zero_init_steps:
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
@@ -149,7 +149,6 @@ class FrequencyDecoupledGuidance(BaseGuidance):
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
enabled: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
@@ -161,7 +160,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop, enabled)
super().__init__(min_start, max_stop)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
@@ -218,11 +217,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})"
)
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
+49 -82
View File
@@ -40,11 +40,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
def __init__(self, start: float = 0.0, stop: float = 1.0):
self._start = start
self._stop = stop
self._step: int = None
@@ -52,7 +48,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = enabled
self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
@@ -64,31 +60,6 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def new(self, **kwargs):
"""
Creates a copy of this guider instance, optionally with modified configuration parameters.
Args:
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
returns an exact copy with the same configuration.
Returns:
A new guider instance with the same (or updated) configuration.
Example:
```python
# Create a CFG guider
guider = ClassifierFreeGuidance(guidance_scale=3.5)
# Create an exact copy
same_guider = guider.new()
# Create a copy with different start step, keeping other config the same
new_guider = guider.new(guidance_scale=5)
```
"""
return self.__class__.from_config(self.config, **kwargs)
def disable(self):
self._enabled = False
@@ -101,52 +72,42 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep
self._count_prepared = 0
def get_state(self) -> Dict[str, Any]:
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
"""
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
the __repr__ method. Returns:
`Dict[str, Any]`: A dictionary containing the current state variables including:
- step: Current inference step
- num_inference_steps: Total number of inference steps
- timestep: Current timestep tensor
- count_prepared: Number of times prepare_models has been called
- enabled: Whether the guidance is enabled
- num_conditions: Number of conditions
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
conditional data identifier and the second element must be the unconditional data identifier or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
state = {
"step": self._step,
"num_inference_steps": self._num_inference_steps,
"timestep": self._timestep,
"count_prepared": self._count_prepared,
"enabled": self._enabled,
"num_conditions": self.num_conditions,
}
return state
def __repr__(self) -> str:
"""
Returns a string representation of the guidance object including both config and current state.
"""
# Get ConfigMixin's __repr__
str_repr = super().__repr__()
# Get current state
state = self.get_state()
# Format each state variable on its own line with indentation
state_lines = []
for k, v in state.items():
# Convert value to string and handle multi-line values
v_str = str(v)
if "\n" in v_str:
# For multi-line values (like MomentumBuffer), indent subsequent lines
v_lines = v_str.split("\n")
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
state_lines.append(f" {k}: {v_str}")
state_str = "\n".join(state_lines)
return f"{str_repr}\nState:\n{state_str}"
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
@@ -194,7 +155,8 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod
def _prepare_batch(
cls,
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
@@ -220,16 +182,21 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {}
for key, value in data.items():
for key, value in input_fields.items():
try:
if isinstance(value, torch.Tensor):
data_batch[key] = value
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = value[tuple_index]
data_batch[key] = getattr(data, value[tuple_index])
else:
raise ValueError(f"Invalid value type: {type(value)}")
except ValueError:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@@ -98,9 +98,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
@@ -169,7 +168,12 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -182,8 +186,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
+9 -5
View File
@@ -100,9 +100,8 @@ class SkipLayerGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
@@ -165,7 +164,12 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -178,8 +182,8 @@ class SkipLayerGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -92,9 +92,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
@@ -154,7 +153,12 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -167,8 +171,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -58,19 +58,23 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
-30
View File
@@ -108,7 +108,6 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
@@ -150,14 +149,6 @@ def _register_attention_processors_metadata():
),
)
# HunyuanImageAttnProcessor
AttentionProcessorRegistry.register(
model_class=HunyuanImageAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -171,10 +162,6 @@ def _register_transformer_blocks_metadata():
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_hunyuanimage import (
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -296,22 +283,6 @@ def _register_transformer_blocks_metadata():
),
)
# HunyuanImage2.1
TransformerBlockRegistry.register(
model_class=HunyuanImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanImageSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -337,5 +308,4 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
+5 -25
View File
@@ -1977,34 +1977,14 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
if any("head.head" in k for k in original_state_dict):
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
f"head.head.{lora_up_key}.weight"
)
if any("head.head" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
# an identity.
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
if f"head.head.{lora_down_key}.weight" in state_dict:
logger.info(
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
)
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
).T
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:
-10
View File
@@ -36,8 +36,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -84,7 +82,6 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -94,13 +91,11 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -137,8 +132,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -175,7 +168,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -189,7 +181,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
@@ -201,7 +192,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SD3Transformer2DModel,
@@ -27,8 +27,6 @@ if torch.distributed.is_available():
from ..utils import (
get_logger,
is_aiter_available,
is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
@@ -49,7 +47,6 @@ if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
@@ -57,7 +54,6 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -82,12 +78,6 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
@@ -188,9 +178,6 @@ class AttentionBackendName(str, Enum):
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# `aiter`
AITER = "aiter"
# PyTorch native
FLEX = "flex"
NATIVE = "native"
@@ -427,12 +414,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -1416,47 +1397,6 @@ def _flash_varlen_attention_3(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.AITER,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=True,
)
else:
out = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
@@ -5,8 +5,6 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.
@@ -107,6 +107,9 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
from .vae import DecoderOutput, EncoderOutput
class ResBlock(nn.Module):
@@ -378,7 +378,7 @@ class Decoder(nn.Module):
return hidden_states
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629).
@@ -536,6 +536,27 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
@@ -32,10 +32,10 @@ from ..attention_processor import (
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -138,6 +138,35 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -28,7 +28,6 @@ from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
from .vae import AutoencoderMixin
class AllegroTemporalConvLayer(nn.Module):
@@ -674,7 +673,7 @@ class AllegroDecoder3D(nn.Module):
return sample
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -796,6 +795,35 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_size - self.tile_overlap_w,
)
def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module):
return hidden_states, new_conv_cache
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -1124,6 +1124,27 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -24,7 +24,7 @@ from ...utils import get_logger
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
from .vae import DecoderOutput, IdentityDistribution
logger = get_logger(__name__)
@@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module):
return hidden_states
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
@@ -1031,6 +1031,27 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
@@ -26,7 +26,7 @@ from ..activations import get_activation
from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -624,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -763,6 +763,27 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -1,709 +0,0 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageResnetBlock(nn.Module):
r"""
Residual block with two convolutions and optional channel change.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x):
# Apply shortcut connection
residual = x
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
# Add residual connection
return x + residual
class HunyuanImageAttentionBlock(nn.Module):
r"""
Self-attention with a single head.
Args:
in_channels (int): The number of channels in the input tensor.
"""
def __init__(self, in_channels: int):
super().__init__()
# layers
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
self.proj = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
identity = x
x = self.norm(x)
# compute query, key, value
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, height, width = query.shape
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
# apply attention
x = F.scaled_dot_product_attention(query, key, value)
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
# output projection
x = self.proj(x)
return x + identity
class HunyuanImageDownsample(nn.Module):
"""
Downsampling block for spatial reduction.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
if out_channels % factor != 0:
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.group_size = factor * in_channels // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
h = h.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = x.shape
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageUpsample(nn.Module):
"""
Upsampling block for spatial expansion.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
self.repeats = factor * out_channels // in_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
h = h.reshape(B, C // 4, H * 2, W * 2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
B, C, H, W = shortcut.shape
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
return h + shortcut
class HunyuanImageMidBlock(nn.Module):
"""
Middle block for HunyuanImageVAE encoder and decoder.
Args:
in_channels (int): Number of input channels.
num_layers (int): Number of layers.
"""
def __init__(self, in_channels: int, num_layers: int = 1):
super().__init__()
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
attentions = []
for _ in range(num_layers):
attentions.append(HunyuanImageAttentionBlock(in_channels))
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnets[0](x)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
x = attn(x)
x = resnet(x)
return x
class HunyuanImageEncoder2D(nn.Module):
r"""
Encoder network that compresses input to latent representation.
Args:
in_channels (int): Number of input channels.
z_channels (int): Number of latent channels.
block_out_channels (list of int): Output channels for each block.
num_res_blocks (int): Number of residual blocks per block.
spatial_compression_ratio (int): Spatial downsampling factor.
non_linearity (str): Type of non-linearity to use. Default is "silu".
downsample_match_channel (bool): Whether to match channels during downsampling.
"""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
non_linearity: str = "silu",
downsample_match_channel: bool = True,
):
super().__init__()
if block_out_channels[-1] % (2 * z_channels) != 0:
raise ValueError(
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
)
self.in_channels = in_channels
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.spatial_compression_ratio = spatial_compression_ratio
self.group_size = block_out_channels[-1] // (2 * z_channels)
self.nonlinearity = get_activation(non_linearity)
# init block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
block_in_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
# residual blocks
for _ in range(num_res_blocks):
self.down_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# downsample block
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if downsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.down_blocks.append(
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# middle blocks
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
# output blocks
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_in(x)
## downsamples
for down_block in self.down_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(down_block, x)
else:
x = down_block(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(self.mid_block, x)
else:
x = self.mid_block(x)
## head
B, C, H, W = x.shape
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
x = self.norm_out(x)
x = self.nonlinearity(x)
x = self.conv_out(x)
return x + residual
class HunyuanImageDecoder2D(nn.Module):
r"""
Decoder network that reconstructs output from latent representation.
Args:
z_channels : int
Number of latent channels.
out_channels : int
Number of output channels.
block_out_channels : Tuple[int, ...]
Output channels for each block.
num_res_blocks : int
Number of residual blocks per block.
spatial_compression_ratio : int
Spatial upsampling factor.
upsample_match_channel : bool
Whether to match channels during upsampling.
non_linearity (str): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
upsample_match_channel: bool = True,
non_linearity: str = "silu",
):
super().__init__()
if block_out_channels[0] % z_channels != 0:
raise ValueError(
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
)
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.repeat = block_out_channels[0] // z_channels
self.spatial_compression_ratio = spatial_compression_ratio
self.nonlinearity = get_activation(non_linearity)
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# Middle blocks with attention
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
# Upsampling blocks
block_in_channel = block_out_channels[0]
self.up_blocks = nn.ModuleList()
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
for _ in range(self.num_res_blocks + 1):
self.up_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if upsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
block_in_channel = block_out_channel
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid_block, h)
else:
h = self.mid_block(h)
for up_block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(up_block, h)
else:
h = up_block(h)
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
# fmt: off
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
spatial_compression_ratio: int,
sample_size: int,
scaling_factor: float = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
) -> None:
# fmt: on
super().__init__()
self.encoder = HunyuanImageEncoder2D(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageDecoder2D(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
# Tiling and slicing configuration
self.use_slicing = False
self.use_tiling = False
# Tiling parameters
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // spatial_compression_ratio
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_size: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
allow processing larger images.
Args:
tile_sample_min_size (`int`, *optional*):
The minimum size required for a sample to be separated into tiles across the spatial dimension.
tile_overlap_factor (`float`, *optional*):
The overlap factor required for a latent to be separated into tiles across the spatial dimension.
"""
self.use_tiling = True
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self.tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
batch_size, num_channels, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode input using spatial tiling strategy.
Args:
x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
Returns:
`torch.Tensor`:
The latent representation of the encoded images.
"""
_, _, _, height, width = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode latent using spatial tiling strategy.
Args:
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, height, width = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -1,934 +0,0 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageRefinerCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
return self.conv(hidden_states)
class HunyuanImageRefinerRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.norm(x)
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, frames, height, width = query.shape
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
# batch_size, 1, frames * height * width, channels
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
x = self.proj_out(x)
return x + identity
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
Args:
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
r1: temporal upsampling factor
r2: height upsampling factor
r3: width upsampling factor
"""
b, packed_c, f, h, w = tensor.shape
factor = r1 * r2 * r3
c = packed_c // factor
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
h = h[:, : h.shape[1] // 2]
# shortcut computation
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
else:
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
return h + shortcut
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
This packs spatial/temporal dimensions into channels (opposite of upsample)
"""
b, c, packed_f, packed_h, packed_w = tensor.shape
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
h = torch.cat([h, h], dim=1)
# shortcut computation
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
else:
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageRefinerResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
class HunyuanImageRefinerMidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
add_attention: bool = True,
) -> None:
super().__init__()
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
else:
attentions.append(None)
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanImageRefinerDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample_out_channels: Optional[int] = None,
add_temporal_downsample: int = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample_out_channels is not None:
self.downsamplers = nn.ModuleList(
[
HunyuanImageRefinerDownsampleDCAE(
out_channels,
out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
upsample_out_channels: Optional[int] = None,
add_temporal_upsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample_out_channels is not None:
self.upsamplers = nn.ModuleList(
[
HunyuanImageRefinerUpsampleDCAE(
out_channels,
out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerEncoder3D(nn.Module):
r"""
3D vae encoder for HunyuanImageRefiner.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 64,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 16,
downsample_match_channel: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.group_size = block_out_channels[-1] // self.out_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
output_channel = block_out_channels[i]
if not add_spatial_downsample:
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=None,
add_temporal_downsample=False,
)
input_channel = output_channel
else:
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
input_channel = downsample_out_channels
self.down_blocks.append(down_block)
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states += short_cut
return hidden_states
class HunyuanImageRefinerDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
"""
def __init__(
self,
in_channels: int = 32,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
upsample_match_channel: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.in_channels = in_channels
self.out_channels = out_channels
self.repeat = block_out_channels[0] // self.in_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
# up
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
output_channel = block_out_channels[i]
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
if add_spatial_upsample or add_temporal_upsample:
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
input_channel = upsample_out_channels
else:
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=None,
add_temporal_upsample=False,
)
input_channel = output_channel
self.up_blocks.append(up_block)
# out
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
scaling_factor: float = 1.03682,
) -> None:
super().__init__()
self.encoder = HunyuanImageRefinerEncoder3D(
in_channels=in_channels,
out_channels=latent_channels * 2,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageRefinerDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
x = self.encoder(x)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z)
dec = self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
class LTXVideoCausalConv3d(nn.Module):
@@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module):
return hidden_states
class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1219,6 +1219,27 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module):
return hidden_states
class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
@@ -805,6 +805,27 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def _encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module):
return hidden_states, new_conv_cache
class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLMochi(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models).
@@ -818,6 +818,27 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _enable_framewise_encoding(self):
r"""
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
@@ -31,7 +31,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -663,7 +663,7 @@ class QwenImageDecoder3d(nn.Module):
return x
class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -763,6 +763,27 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
def _count_conv3d(model):
count = 0
@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module):
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -25,7 +25,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
return x
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.downsampler(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.upsampler(x, feat_cache, feat_idx)
else:
x = self.upsampler(x)
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -951,7 +951,7 @@ def unpatchify(x, patch_size):
return x
class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -961,9 +961,6 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""
_supports_gradient_checkpointing = False
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
@register_to_config
def __init__(
@@ -1113,6 +1110,27 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
@@ -1337,18 +1355,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width
if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = self.tile_sample_min_height - tile_sample_stride_height
blend_width = self.tile_sample_min_width - tile_sample_stride_width
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1362,9 +1371,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1380,15 +1387,11 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)
dec = torch.clamp(dec, min=-1.0, max=1.0)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@@ -25,7 +25,6 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin
class Snake1d(nn.Module):
@@ -292,7 +291,7 @@ class OobleckDecoder(nn.Module):
return hidden_state
class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderOobleck(ModelMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
@@ -357,6 +356,20 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
self.use_slicing = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
@dataclass
@@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput):
latents: torch.Tensor
class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderTiny(ModelMixin, ConfigMixin):
r"""
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
@@ -162,6 +162,35 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
@@ -32,7 +32,7 @@ from ..attention_processor import (
)
from ..modeling_utils import ModelMixin
from ..unets.unet_2d import UNet2DModel
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"
class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
r"""
The consistency decoder used with DALL-E 3.
@@ -167,6 +167,39 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
+3 -35
View File
@@ -286,9 +286,11 @@ class Decoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -296,6 +298,7 @@ class Decoder(nn.Module):
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -891,38 +894,3 @@ class DecoderTiny(nn.Module):
# scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1)
class AutoencoderMixin:
def enable_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
if not hasattr(self, "use_tiling"):
raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.")
self.use_tiling = True
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
if not hasattr(self, "use_slicing"):
raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.")
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@@ -22,7 +22,6 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin
@dataclass
@@ -38,7 +37,7 @@ class VQEncoderOutput(BaseOutput):
latents: torch.Tensor
class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
class VQModel(ModelMixin, ConfigMixin):
r"""
A VQ-VAE model for decoding latent representations.
+2 -10
View File
@@ -319,17 +319,13 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
`torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -345,11 +341,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# Auto-detect appropriate dtype if not specified
if dtype is None:
dtype = torch.float32 if pos.device.type == "mps" else torch.float64
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
-1
View File
@@ -251,7 +251,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_repeated_blocks = []
_parallel_config = None
_cp_plan = None
_skip_keys = None
def __init__(self):
super().__init__()
@@ -18,7 +18,6 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
@@ -28,13 +27,11 @@ if is_torch_available():
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
@@ -1,655 +0,0 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
class BriaFiboAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "BriaFiboAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = BriaFiboAttnProcessor
_available_processors = [BriaFiboAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class BriaFiboEmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class BriaFiboSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = BriaAttnProcessor()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class BriaFiboTextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
@maybe_allow_in_graph
# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
class BriaFiboTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = BriaFiboAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=BriaFiboAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class BriaFiboTimesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class BriaFiboTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = BriaFiboTimesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
...
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
text_encoder_dim: int = 2048,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
if guidance_embeds:
self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BriaFiboTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
BriaFiboSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
caption_projection = [
BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
for i in range(self.config.num_layers + self.config.num_single_layers)
]
self.caption_projection = nn.ModuleList(caption_projection)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
text_encoder_layers: list = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 3:
txt_ids = txt_ids[0]
if len(img_ids.shape) == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
new_text_encoder_layers = []
for i, text_encoder_layer in enumerate(text_encoder_layers):
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
new_text_encoder_layers.append(text_encoder_layer)
text_encoder_layers = new_text_encoder_layers
block_id = 0
for index_block, block in enumerate(self.transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
"""
The Transformer model introduced in Flux, modified for Chroma.
Reference: https://huggingface.co/lodestones/Chroma1-HD
Reference: https://huggingface.co/lodestones/Chroma
Args:
patch_size (`int`, defaults to `1`):
@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
@@ -1,971 +0,0 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn.add_q_proj is None and encoder_hidden_states is not None:
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if attn.add_q_proj is None and encoder_hidden_states is not None:
query = torch.cat(
[
apply_rotary_emb(
query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
),
query[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
key = torch.cat(
[
apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
key[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
else:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)
# 5. Attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class HunyuanImagePatchEmbed(nn.Module):
def __init__(
self,
patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
self.patch_size = patch_size
if len(patch_size) == 2:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
elif len(patch_size) == 3:
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
else:
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
return hidden_states
class HunyuanImageByT5TextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, out_features: int):
super().__init__()
self.norm = nn.LayerNorm(in_features)
self.linear_1 = nn.Linear(in_features, hidden_size)
self.linear_2 = nn.Linear(hidden_size, hidden_size)
self.linear_3 = nn.Linear(hidden_size, out_features)
self.act_fn = nn.GELU()
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(encoder_hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_3(hidden_states)
return hidden_states
class HunyuanImageAdaNorm(nn.Module):
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
self.linear = nn.Linear(in_features, out_features)
self.nonlinearity = nn.SiLU()
def forward(
self, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
return gate_msa, gate_mlp
class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
guidance_embeds: bool = False,
use_meanflow: bool = False,
):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_meanflow = use_meanflow
self.time_proj_r = None
self.timestep_embedder_r = None
if use_meanflow:
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = None
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self,
timestep: torch.Tensor,
timestep_r: Optional[torch.Tensor] = None,
guidance: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
if timestep_r is not None:
timesteps_proj_r = self.time_proj_r(timestep_r)
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
conditioning = timesteps_emb + guidance_emb
else:
conditioning = timesteps_emb
return conditioning
# IndividualTokenRefinerBlock
@maybe_allow_in_graph
class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int, # 28
attention_head_dim: int, # 128
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
)
gate_msa, gate_mlp = self.norm_out(temb)
hidden_states = hidden_states + attn_output * gate_msa
ff_output = self.ff(self.norm2(hidden_states))
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
class HunyuanImageIndividualTokenRefiner(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
self.refiner_blocks = nn.ModuleList(
[
HunyuanImageIndividualTokenRefinerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
for _ in range(num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> None:
self_attn_mask = None
if attention_mask is not None:
batch_size = attention_mask.shape[0]
seq_len = attention_mask.shape[1]
attention_mask = attention_mask.to(hidden_states.device)
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True
for block in self.refiner_blocks:
hidden_states = block(hidden_states, temb, self_attn_mask)
return hidden_states
# txt_in
class HunyuanImageTokenRefiner(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=hidden_size, pooled_projection_dim=in_channels
)
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
self.token_refiner = HunyuanImageIndividualTokenRefiner(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=num_layers,
mlp_width_ratio=mlp_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_hidden_states = hidden_states.mean(dim=1)
else:
original_dtype = hidden_states.dtype
mask_float = attention_mask.float().unsqueeze(-1)
pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
pooled_hidden_states = pooled_hidden_states.to(original_dtype)
temb = self.time_text_embed(timestep, pooled_hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
return hidden_states
class HunyuanImageRotaryPosEmbed(nn.Module):
def __init__(
self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
) -> None:
super().__init__()
if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
if not len(patch_size) == len(rope_dim):
raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
self.patch_size = patch_size
self.rope_dim = rope_dim
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if hidden_states.ndim == 5:
_, _, frame, height, width = hidden_states.shape
patch_size_frame, patch_size_height, patch_size_width = self.patch_size
rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
elif hidden_states.ndim == 4:
_, _, height, width = hidden_states.shape
patch_size_height, patch_size_width = self.patch_size
rope_sizes = [height // patch_size_height, width // patch_size_width]
else:
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
axes_grids = []
for i in range(len(rope_sizes)):
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
axes_grids.append(grid)
grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
grid = torch.stack(grid, dim=0) # [2, H, W]
freqs = []
for i in range(len(rope_sizes)):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class HunyuanImageSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanImageAttnProcessor(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class HunyuanImageTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanImageAttnProcessor(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 3. Modulation and residual connection
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
Args:
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `20`):
The number of layers of dual-stream blocks to use.
num_single_layers (`int`, defaults to `40`):
The number of layers of single-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`int`, defaults to `2`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
qk_norm (`str`, defaults to `rms_norm`):
The normalization to use for the query and key projections in the attention layers.
guidance_embeds (`bool`, defaults to `True`):
Whether to use guidance embeddings in the model.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
pooled_projection_dim (`int`, defaults to `768`):
The dimension of the pooled projection of the text embeddings.
rope_theta (`float`, defaults to `256.0`):
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
image_condition_type (`str`, *optional*, defaults to `None`):
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
tokens in the latent stream and apply conditioning.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanImageTransformerBlock",
"HunyuanImageSingleTransformerBlock",
"HunyuanImagePatchEmbed",
"HunyuanImageTokenRefiner",
]
_repeated_blocks = [
"HunyuanImageTransformerBlock",
"HunyuanImageSingleTransformerBlock",
]
@register_to_config
def __init__(
self,
in_channels: int = 64,
out_channels: int = 64,
num_attention_heads: int = 28,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: Tuple[int, int] = (1, 1),
qk_norm: str = "rms_norm",
guidance_embeds: bool = False,
text_embed_dim: int = 3584,
text_embed_2_dim: Optional[int] = None,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (64, 64),
use_meanflow: bool = False,
) -> None:
super().__init__()
if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
self.context_embedder = HunyuanImageTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
if text_embed_2_dim is not None:
self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
else:
self.context_embedder_2 = None
self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
# 2. RoPE
self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanImageTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanImageSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
timestep_r: Optional[torch.LongTensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
encoder_attention_mask_2: Optional[torch.Tensor] = None,
guidance: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
sizes = (height, width)
elif hidden_states.ndim == 5:
batch_size, channels, frame, height, width = hidden_states.shape
sizes = (frame, height, width)
else:
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
encoder_attention_mask = encoder_attention_mask.bool()
temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
# reorder and combine text tokens: combine valid tokens first, then padding
new_encoder_hidden_states = []
new_encoder_attention_mask = []
for text, text_mask, text_2, text_mask_2 in zip(
encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
):
# Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
new_encoder_hidden_states.append(
torch.cat(
[
text_2[text_mask_2], # valid byt5
text[text_mask], # valid mllm
text_2[~text_mask_2], # invalid byt5
text[~text_mask], # invalid mllm
],
dim=0,
)
)
# Apply same reordering to attention masks
new_encoder_attention_mask.append(
torch.cat(
[
text_mask_2[text_mask_2],
text_mask[text_mask],
text_mask_2[~text_mask_2],
text_mask[~text_mask],
],
dim=0,
)
)
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# 3. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 4. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# 5. unpatchify
# reshape: [batch_size, *post_patch_dims, channels, *patch_size]
out_channels = self.config.out_channels
reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
hidden_states = hidden_states.reshape(*reshape_dims)
# create permutation pattern: batch, channels, then interleave post_patch and patch dims
# For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
# For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
ndim = len(post_patch_sizes)
permute_pattern = [0, ndim + 1] # batch, channels
for i in range(ndim):
permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
hidden_states = hidden_states.permute(*permute_pattern)
# flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
# batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
final_dims = [batch_size, out_channels] + [
post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
]
hidden_states = hidden_states.reshape(*final_dims)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
@@ -324,7 +324,6 @@ class Kandinsky5AttnProcessor:
sparse_params["sta_mask"],
thr=sparse_params["P"],
)
else:
attn_mask = None
@@ -336,7 +335,6 @@ class Kandinsky5AttnProcessor:
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(-2, -1)
attn_out = attn.out_layer(hidden_states)
@@ -1,770 +0,0 @@
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn.functional import fold, unfold
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__)
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
r"""
Generates 2D patch coordinate indices for a batch of images.
Args:
batch_size (`int`):
Number of images in the batch.
height (`int`):
Height of the input images (in pixels).
width (`int`):
Width of the input images (in pixels).
patch_size (`int`):
Size of the square patches that the image is divided into.
device (`torch.device`):
The device on which to create the tensor.
Returns:
`torch.Tensor`:
Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
image grid.
"""
img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
r"""
Applies rotary positional embeddings (RoPE) to a query tensor.
Args:
xq (`torch.Tensor`):
Input tensor of shape `(..., dim)` representing the queries.
freqs_cis (`torch.Tensor`):
Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
Returns:
`torch.Tensor`:
Tensor of the same shape as `xq` with rotary embeddings applied.
"""
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
# Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq)
class PRXAttnProcessor2_0:
r"""
Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: "PRXAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Apply PRX attention using PRXAttention module.
Args:
attn: PRXAttention module containing projection layers
hidden_states: Image tokens [B, L_img, D]
encoder_hidden_states: Text tokens [B, L_txt, D]
attention_mask: Boolean mask for text tokens [B, L_txt]
image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
"""
if encoder_hidden_states is None:
raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
# Project image tokens to Q, K, V
img_qkv = attn.img_qkv_proj(hidden_states)
B, L_img, _ = img_qkv.shape
img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
# Apply QK normalization to image tokens
img_q = attn.norm_q(img_q)
img_k = attn.norm_k(img_k)
# Project text tokens to K, V
txt_kv = attn.txt_kv_proj(encoder_hidden_states)
B, L_txt, _ = txt_kv.shape
txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
txt_k, txt_v = txt_kv[0], txt_kv[1]
# Apply K normalization to text tokens
txt_k = attn.norm_added_k(txt_k)
# Apply RoPE to image queries and keys
if image_rotary_emb is not None:
img_q = apply_rope(img_q, image_rotary_emb)
img_k = apply_rope(img_k, image_rotary_emb)
# Concatenate text and image keys/values
k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
# Build attention mask if provided
attn_mask_tensor = None
if attention_mask is not None:
bs, _, l_img, _ = img_q.shape
l_txt = txt_k.shape[2]
if attention_mask.dim() != 2:
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
if attention_mask.shape[-1] != l_txt:
raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
device = img_q.device
ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
attention_mask = attention_mask.to(device=device, dtype=torch.bool)
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
# Apply attention using dispatch_attention_fn for backend support
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
query = img_q.transpose(1, 2) # [B, L_img, H, D]
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
attn_output = dispatch_attention_fn(
query,
key,
value,
attn_mask=attn_mask_tensor,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
batch_size, seq_len, num_heads, head_dim = attn_output.shape
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
# Apply output projection
attn_output = attn.to_out[0](attn_output)
if len(attn.to_out) > 1:
attn_output = attn.to_out[1](attn_output) # dropout if present
return attn_output
class PRXAttention(nn.Module, AttentionModuleMixin):
r"""
PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
PRX's architecture.
"""
_default_processor_cls = PRXAttnProcessor2_0
_available_processors = [PRXAttnProcessor2_0]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
bias: bool = False,
out_bias: bool = False,
eps: float = 1e-6,
processor=None,
):
super().__init__()
self.heads = heads
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(0.0))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
**kwargs,
)
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PRXEmbedND(nn.Module):
r"""
N-dimensional rotary positional embedding.
This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
dimension. The embeddings are combined and returned as a single tensor
Args:
dim (int):
Base embedding dimension (must be even).
theta (int):
Scaling factor that controls the frequency spectrum of the rotary embeddings.
axes_dim (list[int]):
List of embedding dimensions for each axis (each must be even).
"""
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
# Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
# out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
out = out.reshape(*out.shape[:-1], 2, 2)
return out.float()
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
class MLPEmbedder(nn.Module):
r"""
A simple 2-layer MLP used for embedding inputs.
Args:
in_dim (`int`):
Dimensionality of the input features.
hidden_dim (`int`):
Dimensionality of the hidden and output embedding space.
Returns:
`torch.Tensor`:
Tensor of shape `(..., hidden_dim)` containing the embedded representations.
"""
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class Modulation(nn.Module):
r"""
Modulation network that generates scale, shift, and gating parameters.
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
two tuples `(shift, scale, gate)`.
Args:
dim (`int`):
Dimensionality of the input vector. The output will have `6 * dim` features internally.
Returns:
((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
Two tuples `(shift, scale, gate)`.
"""
def __init__(self, dim: int):
super().__init__()
self.lin = nn.Linear(dim, 6 * dim, bias=True)
nn.init.constant_(self.lin.weight, 0)
nn.init.constant_(self.lin.bias, 0)
def forward(
self, vec: torch.Tensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
return tuple(out[:3]), tuple(out[3:])
class PRXBlock(nn.Module):
r"""
Multimodal transformer block with textimage cross-attention, modulation, and MLP.
Args:
hidden_size (`int`):
Dimension of the hidden representations.
num_heads (`int`):
Number of attention heads.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Expansion ratio for the hidden dimension inside the MLP.
qk_scale (`float`, *optional*):
Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
Attributes:
img_pre_norm (`nn.LayerNorm`):
Pre-normalization applied to image tokens before attention.
attention (`PRXAttention`):
Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
image and text tokens.
post_attention_layernorm (`nn.LayerNorm`):
Normalization applied after attention.
gate_proj / up_proj / down_proj (`nn.Linear`):
Feedforward layers forming the gated MLP.
mlp_act (`nn.GELU`):
Nonlinear activation used in the MLP.
modulation (`Modulation`):
Produces scale/shift/gating parameters for modulated layers.
Methods:
The forward method performs cross-attention and the MLP with modulation.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.hidden_size = hidden_size
# Pre-attention normalization for image tokens
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# PRXAttention module with built-in projections and norms
self.attention = PRXAttention(
query_dim=hidden_size,
heads=num_heads,
dim_head=self.head_dim,
bias=False,
out_bias=False,
eps=1e-6,
processor=PRXAttnProcessor2_0(),
)
# mlp
self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Dict[str, Any],
) -> torch.Tensor:
r"""
Runs modulation-gated cross-attention and MLP, with residual connections.
Args:
hidden_states (`torch.Tensor`):
Image tokens of shape `(B, L_img, hidden_size)`.
encoder_hidden_states (`torch.Tensor`):
Text tokens of shape `(B, L_txt, hidden_size)`.
temb (`torch.Tensor`):
Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
broadcastable).
image_rotary_emb (`torch.Tensor`):
Rotary positional embeddings applied inside attention.
attention_mask (`torch.Tensor`, *optional*):
Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
**kwargs:
Additional keyword arguments for API compatibility.
Returns:
`torch.Tensor`:
Updated image tokens of shape `(B, L_img, hidden_size)`.
"""
mod_attn, mod_mlp = self.modulation(temb)
attn_shift, attn_scale, attn_gate = mod_attn
mlp_shift, mlp_scale, mlp_gate = mod_mlp
hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
attn_out = self.attention(
hidden_states=hidden_states_mod,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_gate * attn_out
x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
return hidden_states
class FinalLayer(nn.Module):
r"""
Final projection layer with adaptive LayerNorm modulation.
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
outputs.
Args:
hidden_size (`int`):
Dimensionality of the input tokens.
patch_size (`int`):
Size of the square image patches.
out_channels (`int`):
Number of output channels per pixel (e.g. RGB = 3).
Forward Inputs:
x (`torch.Tensor`):
Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
vec (`torch.Tensor`):
Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
LayerNorm.
Returns:
`torch.Tensor`:
Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
"""
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
r"""
Flattens an image tensor into a sequence of non-overlapping patches.
Args:
img (`torch.Tensor`):
Input image tensor of shape `(B, C, H, W)`.
patch_size (`int`):
Size of each square patch. Must evenly divide both `H` and `W`.
Returns:
`torch.Tensor`:
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
// patch_size)` is the number of patches.
"""
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
r"""
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
Args:
seq (`torch.Tensor`):
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
patch_size)`.
patch_size (`int`):
Size of each square patch.
shape (`tuple` or `torch.Tensor`):
The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
height and width.
Returns:
`torch.Tensor`:
Reconstructed image tensor of shape `(B, C, H, W)`.
"""
if isinstance(shape, tuple):
shape = shape[-2:]
elif isinstance(shape, torch.Tensor):
shape = (int(shape[0]), int(shape[1]))
else:
raise NotImplementedError(f"shape type {type(shape)} not supported")
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
Transformer-based 2D model for text to image generation.
Args:
in_channels (`int`, *optional*, defaults to 16):
Number of input channels in the latent image.
patch_size (`int`, *optional*, defaults to 2):
Size of the square patches used to flatten the input image.
context_in_dim (`int`, *optional*, defaults to 2304):
Dimensionality of the text conditioning input.
hidden_size (`int`, *optional*, defaults to 1792):
Dimension of the hidden representation.
mlp_ratio (`float`, *optional*, defaults to 3.5):
Expansion ratio for the hidden dimension inside MLP blocks.
num_heads (`int`, *optional*, defaults to 28):
Number of attention heads.
depth (`int`, *optional*, defaults to 16):
Number of transformer blocks.
axes_dim (`list[int]`, *optional*):
List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
theta (`int`, *optional*, defaults to 10000):
Frequency scaling factor for rotary embeddings.
time_factor (`float`, *optional*, defaults to 1000.0):
Scaling factor applied in timestep embeddings.
time_max_period (`int`, *optional*, defaults to 10000):
Maximum frequency period for timestep embeddings.
Attributes:
pe_embedder (`EmbedND`):
Multi-axis rotary embedding generator for positional encodings.
img_in (`nn.Linear`):
Projection layer for image patch tokens.
time_in (`MLPEmbedder`):
Embedding layer for timestep embeddings.
txt_in (`nn.Linear`):
Projection layer for text conditioning.
blocks (`nn.ModuleList`):
Stack of transformer blocks (`PRXBlock`).
final_layer (`LastLayer`):
Projection layer mapping hidden tokens back to patch outputs.
Methods:
attn_processors:
Returns a dictionary of all attention processors in the model.
set_attn_processor(processor):
Replaces attention processors across all attention layers.
process_inputs(image_latent, txt):
Converts inputs into patch tokens, encodes text, and produces positional encodings.
compute_timestep_embedding(timestep, dtype):
Creates a timestep embedding of dimension 256, scaled and projected.
forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask,
**block_kwargs):
Runs the sequence of transformer blocks over image and text tokens.
forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None,
attention_kwargs=None, return_dict=True):
Full forward pass from latent input to reconstructed output image.
Returns:
`Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
- `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
"""
config_name = "config.json"
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
patch_size: int = 2,
context_in_dim: int = 2304,
hidden_size: int = 1792,
mlp_ratio: float = 3.5,
num_heads: int = 28,
depth: int = 16,
axes_dim: list = None,
theta: int = 10000,
time_factor: float = 1000.0,
time_max_period: int = 10000,
):
super().__init__()
if axes_dim is None:
axes_dim = [32, 32]
# Store parameters directly
self.in_channels = in_channels
self.patch_size = patch_size
self.out_channels = self.in_channels * self.patch_size**2
self.time_factor = time_factor
self.time_max_period = time_max_period
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.blocks = nn.ModuleList(
[
PRXBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
)
for i in range(depth)
]
)
self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
self.gradient_checkpointing = False
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return self.time_in(
get_timestep_embedding(
timesteps=timestep,
embedding_dim=256,
max_period=self.time_max_period,
scale=self.time_factor,
flip_sin_to_cos=True, # Match original cos, sin order
).to(dtype)
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
r"""
Forward pass of the PRXTransformer2DModel.
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
Args:
hidden_states (`torch.Tensor`):
Input latent image tensor of shape `(B, C, H, W)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
encoder_hidden_states (`torch.Tensor`):
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
attention_mask (`torch.Tensor`, *optional*):
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
attention_kwargs (`dict`, *optional*):
Additional arguments passed to attention layers.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a `Transformer2DModelOutput` or a tuple.
Returns:
`Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
"""
# Process text conditioning
txt = self.txt_in(encoder_hidden_states)
# Convert image to sequence and embed
img = img2seq(hidden_states, self.patch_size)
img = self.img_in(img)
# Generate positional embeddings
bs, _, h, w = hidden_states.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
pe = self.pe_embedder(img_ids)
# Compute time embedding
vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
# Apply transformer blocks
for block in self.blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
img = self._gradient_checkpointing_func(
block.__call__,
img,
txt,
vec,
pe,
attention_mask,
)
else:
img = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=pe,
attention_mask=attention_mask,
)
# Final layer and convert back to image
img = self.final_layer(img, vec)
output = seq2img(img, self.patch_size, hidden_states.shape)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -130,14 +130,8 @@ class PipelineState:
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
intermediates dict.
"""
# Use object.__getattribute__ to avoid infinite recursion during deepcopy
try:
values = object.__getattribute__(self, "values")
except AttributeError:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
if name in values:
return values[name]
if name in self.values:
return self.values[name]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
@@ -305,7 +299,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"cache_dir",
"force_download",
"local_files_only",
"local_dir",
"proxies",
"resume_download",
"revision",
@@ -332,10 +325,11 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
block_kwargs = {
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
}
return block_cls(**block_kwargs)
@@ -2498,8 +2492,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
if state is None:
state = PipelineState()
else:
state = deepcopy(state)
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
@@ -238,27 +238,19 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"encoder_hidden_states_mask": (
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
@@ -336,27 +328,19 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"encoder_hidden_states_mask": (
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
@@ -201,41 +201,27 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"time_ids": (
getattr(block_state, "add_time_ids", None),
getattr(block_state, "negative_add_time_ids", None),
),
"text_embeds": (
getattr(block_state, "pooled_prompt_embeds", None),
getattr(block_state, "negative_pooled_prompt_embeds", None),
),
"image_embeds": (
getattr(block_state, "ip_adapter_embeds", None),
getattr(block_state, "negative_ip_adapter_embeds", None),
),
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
@@ -358,23 +344,11 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"time_ids": (
getattr(block_state, "add_time_ids", None),
getattr(block_state, "negative_add_time_ids", None),
),
"text_embeds": (
getattr(block_state, "pooled_prompt_embeds", None),
getattr(block_state, "negative_pooled_prompt_embeds", None),
),
"image_embeds": (
getattr(block_state, "ip_adapter_embeds", None),
getattr(block_state, "negative_ip_adapter_embeds", None),
),
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
# cond_scale for the timestep (controlnet input)
@@ -395,15 +369,12 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
+10 -15
View File
@@ -94,30 +94,25 @@ class WanLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
-6
View File
@@ -128,7 +128,6 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -145,7 +144,6 @@ else:
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -242,7 +240,6 @@ else:
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
]
_import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -563,7 +560,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
@@ -643,7 +639,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ReduxImageEncoder,
)
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
@@ -724,7 +719,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
@@ -1,48 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,838 +0,0 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Example:
```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
torch.set_grad_enabled(False)
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
pipe = BriaFiboPipeline.from_pretrained(
"briaai/FIBO",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
with torch.inference_mode():
# 1. Create a prompt to generate an initial image
output = vlm_pipe(prompt="a beautiful dog")
json_prompt_generate = output.values["json_prompt"]
# Generate the image from the structured json prompt
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
results_generate.images[0].save("image_generate.png")
```
"""
class BriaFiboPipeline(DiffusionPipeline):
r"""
Args:
transformer (`BriaFiboTransformer2DModel`):
The transformer model for 2D diffusion modeling.
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
Scheduler to be used with `transformer` to denoise the encoded latents.
vae (`AutoencoderKLWan`):
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
text_encoder (`SmolLM3ForCausalLM`):
Text encoder for processing input prompts.
tokenizer (`AutoTokenizer`):
Tokenizer used for processing the input text prompts for the text_encoder.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: BriaFiboTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKLWan,
text_encoder: SmolLM3ForCausalLM,
tokenizer: AutoTokenizer,
):
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 64
def get_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
if not prompt:
raise ValueError("`prompt` must be a non-empty string or list of strings.")
batch_size = len(prompt)
bot_token_id = 128000
text_encoder_device = device if device is not None else torch.device("cpu")
if not isinstance(text_encoder_device, torch.device):
text_encoder_device = torch.device(text_encoder_device)
if all(p == "" for p in prompt):
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
attention_mask = torch.ones_like(input_ids)
else:
tokenized = self.tokenizer(
prompt,
padding="longest",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = tokenized.input_ids.to(text_encoder_device)
attention_mask = tokenized.attention_mask.to(text_encoder_device)
if any(p == "" for p in prompt):
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
input_ids[empty_rows] = bot_token_id
attention_mask[empty_rows] = 1
encoder_outputs = self.text_encoder(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_outputs.hidden_states
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hidden_states = tuple(
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
)
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
return prompt_embeds, hidden_states, attention_mask
@staticmethod
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
batch_size, seq_len, dim = prompt_embeds.shape
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
else:
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if max_tokens < seq_len:
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
if max_tokens > seq_len:
pad_length = max_tokens - seq_len
padding = torch.zeros(
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
mask_padding = torch.zeros(
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
return prompt_embeds, attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 3000,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
guidance_scale (`float`):
Guidance scale for classifier free guidance.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
prompt_attention_mask = None
negative_prompt_attention_mask = None
if prompt_embeds is None:
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
if guidance_scale > 1:
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
negative_prompt = ""
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
# Pad to longest
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if negative_prompt_embeds is not None:
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
)
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
)
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
else:
max_tokens = prompt_embeds.shape[1]
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
negative_prompt_layers = None
dtype = self.text_encoder.dtype
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
return (
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@staticmethod
# Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels)
latents = latents.permute(0, 3, 1, 2)
return latents
@staticmethod
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
latents = latents.permute(0, 2, 3, 1)
latents = latents.reshape(batch_size, height * width, num_channels_latents)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
do_patching=False,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if do_patching:
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
else:
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@staticmethod
def _prepare_attention_mask(attention_mask):
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
# convert to 0 - keep, -inf ignore
attention_matrix = torch.where(
attention_matrix == 1, 0.0, -torch.inf
) # Apply -inf to ignored tokens for nulling softmax score
return attention_matrix
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
Examples:
Returns:
[`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
num_images_per_prompt=num_images_per_prompt,
lora_scale=lora_scale,
)
prompt_batch_size = prompt_embeds.shape[0]
if guidance_scale > 1:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_layers = [
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
]
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
self.transformer.single_transformer_blocks
)
if len(prompt_layers) >= total_num_layers_transformer:
# remove first layers
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
else:
# duplicate last layer
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
if do_patching:
num_channels_latents = int(num_channels_latents / 4)
latents, latent_image_ids = self.prepare_latents(
prompt_batch_size,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
do_patching,
)
latent_attention_mask = torch.ones(
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
)
if guidance_scale > 1:
latent_attention_mask = latent_attention_mask.repeat(2, 1)
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
if self._joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask
# Adapt scheduler to dynamic shifting (resolution dependent)
if do_patching:
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
else:
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
# Init sigmas and timesteps according to shift size
# This changes the scheduler in-place according to the dynamic scheduling
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Support old different diffusers versions
if len(latent_image_ids.shape) == 3:
latent_image_ids = latent_image_ids[0]
if len(text_ids.shape) == 3:
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(
device=latent_model_input.device, dtype=latent_model_input.dtype
)
# This is predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_encoder_layers=prompt_layers,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if guidance_scale > 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
if do_patching:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
else:
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
latents = latents.unsqueeze(dim=2)
latents_device = latents[0].device
latents_dtype = latents[0].dtype
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents_device, latents_dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_device, latents_dtype
)
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
image.append(curr_image)
if len(image) == 1:
image = image[0]
else:
image = np.stack(image, axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return BriaFiboPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if max_sequence_length is not None and max_sequence_length > 3000:
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class BriaFiboPipelineOutput(BaseOutput):
"""
Output class for BriaFibo pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline
>>> model_id = "lodestones/Chroma1-HD"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
r"""
The Chroma pipeline for text-to-image generation.
Reference: https://huggingface.co/lodestones/Chroma1-HD/
Reference: https://huggingface.co/lodestones/Chroma/
Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -233,23 +233,20 @@ class ChromaPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
tokenizer_mask = text_inputs.attention_mask
attention_mask = text_inputs.attention_mask.clone()
tokenizer_mask_device = tokenizer_mask.to(device)
# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
# unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=tokenizer_mask_device,
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
)[0]
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
attention_mask = attention_mask.to(device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
>>> model_id = "lodestones/Chroma1-HD"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... model_id,
... transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
r"""
The Chroma pipeline for image-to-image generation.
Reference: https://huggingface.co/lodestones/Chroma1-HD/
Reference: https://huggingface.co/lodestones/Chroma/
Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -247,21 +247,20 @@ class ChromaImg2ImgPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
tokenizer_mask = text_inputs.attention_mask
attention_mask = text_inputs.attention_mask.clone()
tokenizer_mask_device = tokenizer_mask.to(device)
# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=tokenizer_mask_device,
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
)[0]
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
attention_mask = attention_mask.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
max_sequence_length,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
max_sequence_length,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -1,50 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
_import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_hunyuanimage import HunyuanImagePipeline
from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,866 +0,0 @@
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
from ...guiders import AdaptiveProjectedMixGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanImagePipeline
>>> pipe = HunyuanImagePipeline.from_pretrained(
... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
>>> image.save("hunyuanimage.png")
```
"""
def extract_glyph_text(prompt: str):
"""
Extract text enclosed in quotes for glyph rendering.
Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
Args:
prompt: Input text prompt
Returns:
Formatted glyph text string or None if no quoted text found
"""
text_prompt_texts = []
pattern_quote_single = r"\'(.*?)\'"
pattern_quote_double = r"\"(.*?)\""
pattern_quote_chinese_single = r"(.*?)"
pattern_quote_chinese_double = r"“(.*?)”"
matches_quote_single = re.findall(pattern_quote_single, prompt)
matches_quote_double = re.findall(pattern_quote_double, prompt)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
text_prompt_texts.extend(matches_quote_single)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if text_prompt_texts:
glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
else:
glyph_text_formatted = None
return glyph_text_formatted
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HunyuanImagePipeline(DiffusionPipeline):
r"""
The HunyuanImage pipeline for text-to-image generation.
Args:
transformer ([`HunyuanImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanImage`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
text_encoder_2 ([`T5EncoderModel`]):
[T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
variant.
tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
guider ([`AdaptiveProjectedMixGuidance`]):
[AdaptiveProjectedMixGuidance]to be used to guide the image generation.
ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
[AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["ocr_guider", "guider"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLHunyuanImage,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: ByT5Tokenizer,
transformer: HunyuanImageTransformer2DModel,
guider: Optional[AdaptiveProjectedMixGuidance] = None,
ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
guider=guider,
ocr_guider=ocr_guider,
)
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = 1000
self.tokenizer_2_max_length = 128
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
self.prompt_template_encode_start_idx = 34
self.default_sample_size = 64
def _get_qwen_prompt_embeds(
self,
tokenizer: Qwen2Tokenizer,
text_encoder: Qwen2_5_VLForConditionalGeneration,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 1000,
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
drop_idx: int = 34,
hidden_state_skip_layer: int = 2,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
prompt_embeds = prompt_embeds[:, drop_idx:]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = encoder_attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def _get_byt5_prompt_embeds(
self,
tokenizer: ByT5Tokenizer,
text_encoder: T5EncoderModel,
prompt: str,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 128,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
if isinstance(prompt, list):
raise ValueError("byt5 prompt should be a string")
elif prompt is None:
raise ValueError("byt5 prompt should not be None")
txt_tokens = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to(device)
prompt_embeds = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask.float(),
)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
batch_size: int = 1,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
"""
device = device or self._execution_device
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
template=self.prompt_template_encode,
drop_idx=self.prompt_template_encode_start_idx,
)
if prompt_embeds_2 is None:
prompt_embeds_2_list = []
prompt_embeds_mask_2_list = []
glyph_texts = [extract_glyph_text(p) for p in prompt]
for glyph_text in glyph_texts:
if glyph_text is None:
glyph_text_embeds = torch.zeros(
(1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
)
glyph_text_embeds_mask = torch.zeros(
(1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
)
else:
glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
tokenizer=self.tokenizer_2,
text_encoder=self.text_encoder_2,
prompt=glyph_text,
device=device,
tokenizer_max_length=self.tokenizer_2_max_length,
)
prompt_embeds_2_list.append(glyph_text_embeds)
prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
_, seq_len_2, _ = prompt_embeds_2.shape
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
prompt_embeds_2=None,
prompt_embeds_mask_2=None,
negative_prompt_embeds_2=None,
negative_prompt_embeds_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
raise ValueError(
"If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
)
if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
raise ValueError(
"If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
)
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
distilled_guidance_scale: Optional[float] = 3.25,
sigmas: Optional[List[float]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
not provided, will use an empty negative prompt. Ignored when not using guidance. ).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
distilled_guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
ignored.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, text embeddings mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
`negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
negative_prompt_embeds_2=negative_prompt_embeds_2,
negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. prepare prompt embeds
prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds_2=prompt_embeds_2,
prompt_embeds_mask_2=prompt_embeds_mask_2,
)
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
# select guider
if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
# prompt contains ocr and pipeline has a guider for ocr
guider = self.ocr_guider
elif self.guider is not None:
guider = self.guider
# distilled model does not use guidance method, use default guider with enabled=False
else:
guider = AdaptiveProjectedMixGuidance(enabled=False)
if guider._enabled and guider.num_conditions > 1:
(
negative_prompt_embeds,
negative_prompt_embeds_mask,
negative_prompt_embeds_2,
negative_prompt_embeds_mask_2,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds_2=negative_prompt_embeds_2,
prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance (for guidance-distilled model)
if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
if self.transformer.config.guidance_embeds:
guidance = (
torch.tensor(
[distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
)
* 1000.0
)
else:
guidance = None
if self.attention_kwargs is None:
self._attention_kwargs = {}
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if self.transformer.config.use_meanflow:
if i == len(timesteps) - 1:
timestep_r = torch.tensor([0.0], device=device)
else:
timestep_r = timesteps[i + 1]
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
else:
timestep_r = None
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
}
# Step 2: Update guider's internal state for this denoising step
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep,
timestep_r=timestep_r,
guidance=guidance,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HunyuanImagePipelineOutput(images=image)
@@ -1,752 +0,0 @@
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from ...guiders import AdaptiveProjectedMixGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanImageRefinerPipeline
>>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = load_image("path/to/image.png")
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
>>> image.save("hunyuanimage.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class HunyuanImageRefinerPipeline(DiffusionPipeline):
r"""
The HunyuanImage pipeline for text-to-image generation.
Args:
transformer ([`HunyuanImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanImageRefiner`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["guider"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLHunyuanImageRefiner,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2Tokenizer,
transformer: HunyuanImageTransformer2DModel,
guider: Optional[AdaptiveProjectedMixGuidance] = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
guider=guider,
)
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = 256
self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
self.prompt_template_encode_start_idx = 36
self.default_sample_size = 64
self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
def _get_qwen_prompt_embeds(
self,
tokenizer: Qwen2Tokenizer,
text_encoder: Qwen2_5_VLForConditionalGeneration,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tokenizer_max_length: int = 1000,
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
drop_idx: int = 34,
hidden_state_skip_layer: int = 2,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
prompt_embeds = prompt_embeds[:, drop_idx:]
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
encoder_attention_mask = encoder_attention_mask.to(device=device)
return prompt_embeds, encoder_attention_mask
def encode_prompt(
self,
prompt: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
batch_size: int = 1,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
batch_size (`int`):
batch size of prompts, defaults to 1
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
argument.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
argument using self.tokenizer_2 and self.text_encoder_2.
"""
device = device or self._execution_device
if prompt is None:
prompt = [""] * batch_size
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
prompt=prompt,
device=device,
tokenizer_max_length=self.tokenizer_max_length,
template=self.prompt_template_encode,
drop_idx=self.prompt_template_encode_start_idx,
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
def prepare_latents(
self,
image_latents,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
strength=0.25,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, 1, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_latents = strength * noise + (1 - strength) * image_latents
return latents, cond_latents
@staticmethod
def _reorder_image_tokens(image_latents):
image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
image_latents = image_latents.permute(0, 2, 1, 3, 4)
image_latents = image_latents.reshape(
batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
)
image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
return image_latents
@staticmethod
def _restore_image_tokens_order(latents):
"""Restore image tokens order by splitting channels and removing first frame slice."""
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
latents = latents.reshape(
batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
) # B, F*2, C//2, H, W
latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
# Remove first frame slice
latents = latents[:, :, 1:]
return latents
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
image_latents = self._reorder_image_tokens(image_latents)
image_latents = image_latents * self.vae.config.scaling_factor
return image_latents
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
distilled_guidance_scale: Optional[float] = 3.25,
image: Optional[PipelineImageInput] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 4,
sigmas: Optional[List[float]] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, will use an empty negative
prompt. Ignored when not using guidance.
distilled_guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
ignored.
num_images_per_prompt (`int`, *optional*, defaults to 1):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. process image
if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
image_latents = image
else:
image = self.image_processor.preprocess(image, height, width)
image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
# 3.prepare prompt embeds
if self.guider is not None:
guider = self.guider
else:
# distilled model does not use guidance method, use default guider with enabled=False
guider = AdaptiveProjectedMixGuidance(enabled=False)
requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
if requires_unconditional_embeds:
(
negative_prompt_embeds,
negative_prompt_embeds_mask,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
# 4. Prepare latent variables
latents, cond_latents = self.prepare_latents(
image_latents=image_latents,
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=self.latent_channels,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance (this pipeline only supports guidance-distilled models)
if distilled_guidance_scale is None:
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
guidance = (
torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
* 1000.0
)
if self.attention_kwargs is None:
self._attention_kwargs = {}
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
}
# Step 2: Update guider's internal state for this denoising step
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
# Step 3: Prepare batched model inputs based on the guidance method
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = guider.prepare_inputs(guider_inputs)
# Step 4: Run the denoiser for each batch
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
for guider_state_batch in guider_state:
guider.prepare_models(self.transformer)
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
cond_kwargs = {
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
}
# e.g. "pred_cond"/"pred_uncond"
context_name = getattr(guider_state_batch, guider._identifier_key)
with self.transformer.cache_context(context_name):
# Run denoiser and store noise prediction in this batch
guider_state_batch.noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
guidance=guidance,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
# Cleanup model (e.g., remove hooks)
guider.cleanup_models(self.transformer)
# Step 5: Combine predictions using the guidance method
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
# Continuing the CFG example, the guider receives:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
# ]
# And extracts predictions using the __guidance_identifier__:
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
# Then applies CFG formula:
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
noise_pred = guider(guider_state)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
latents = self._restore_image_tokens_order(latents)
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HunyuanImagePipelineOutput(images=image)
@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class HunyuanImagePipelineOutput(BaseOutput):
"""
Output class for HunyuanImage pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@@ -113,7 +113,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
_cut_context=True,
_cut_context=False,
attention_mask: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
):
@@ -173,10 +173,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
)
self.prompt_template_encode_start_idx = 129
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@staticmethod
@@ -386,9 +384,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
if not isinstance(prompt, list):
prompt = [prompt]
batch_size = len(prompt)
prompt = [prompt_clean(p) for p in prompt]
@@ -749,13 +744,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
)
if negative_prompt_embeds_qwen is None:
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
self.encode_prompt(
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt(
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
# 4. Prepare timesteps
@@ -787,8 +780,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
negative_text_rope_pos = (
torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
if negative_prompt_cu_seqlens is not None
torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
if negative_cu_seqlens is not None
else None
)
@@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
return torch.zeros(
(
batch_size * num_images_per_prompt,
max_sequence_length,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
return torch.zeros(
(
batch_size * num_images_per_prompt,
max_sequence_length,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -33,7 +33,6 @@ from ..utils import (
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_maybe_remap_transformers_class,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
@@ -76,7 +75,6 @@ LOADABLE_CLASSES = {
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
"BaseGuidance": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
@@ -358,11 +356,6 @@ def maybe_raise_or_warn(
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -397,11 +390,6 @@ def simple_get_class_obj(library_name, class_name):
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
return class_obj
@@ -428,10 +416,6 @@ def get_class_obj_and_candidates(
# else we just import it from the library.
library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -882,9 +866,6 @@ def load_sub_model(
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
skip_keys = None
if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
skip_keys = loaded_sub_model._skip_keys
if needs_offloading_to_cpu:
dispatch_model(
@@ -893,10 +874,9 @@ def load_sub_model(
device_map=device_map,
force_hooks=True,
main_device=0,
skip_keys=skip_keys,
)
else:
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
return loaded_sub_model
-63
View File
@@ -1,63 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_prx"] = ["PRXPipeline"]
# Import T5GemmaEncoder for pipeline loading compatibility
try:
if is_transformers_available():
import transformers
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
# Patch transformers module directly for serialization
if not hasattr(transformers, "T5GemmaEncoder"):
transformers.T5GemmaEncoder = T5GemmaEncoder
except ImportError:
pass
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_output import PRXPipelineOutput
from .pipeline_prx import PRXPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
@@ -1,35 +0,0 @@
# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class PRXPipelineOutput(BaseOutput):
"""
Output class for PRX pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

Some files were not shown because too many files have changed in this diff Show More