Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul 6bb2f930d8 styling issues. 2025-10-21 04:23:05 -10:00
620 changed files with 9603 additions and 21159 deletions
+1 -1
View File
@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
+8 -29
View File
@@ -42,39 +42,18 @@ jobs:
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
run: |
echo "$CHANGED_FILES"
ALLOWED_IMAGES=(
diffusers-pytorch-cpu
diffusers-pytorch-cuda
diffusers-pytorch-xformers-cuda
diffusers-pytorch-minimum-cuda
diffusers-doc-builder
)
declare -A IMAGES_TO_BUILD=()
for FILE in $CHANGED_FILES; do
for FILE in $CHANGED_FILES; do
# skip anything that isn't still on disk
if [[ ! -e "$FILE" ]]; then
if [[ ! -f "$FILE" ]]; then
echo "Skipping removed file $FILE"
continue
fi
if [[ "$FILE" == docker/*Dockerfile ]]; then
DOCKER_PATH="${FILE%/Dockerfile}"
DOCKER_TAG=$(basename "$DOCKER_PATH")
echo "Building Docker image for $DOCKER_TAG"
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
fi
for IMAGE in "${ALLOWED_IMAGES[@]}"; do
if [[ "$FILE" == docker/${IMAGE}/* ]]; then
IMAGES_TO_BUILD["$IMAGE"]=1
fi
done
done
if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
echo "No relevant Docker changes detected."
exit 0
fi
for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
DOCKER_PATH="docker/${IMAGE}"
echo "Building Docker image for $IMAGE"
docker build -t "$IMAGE" "$DOCKER_PATH"
done
if: steps.file_changes.outputs.all != ''
+1 -1
View File
@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600
+1 -1
View File
@@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install -e .
+1 -1
View File
@@ -26,7 +26,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
+3 -3
View File
@@ -22,7 +22,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
@@ -35,7 +35,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -55,7 +55,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
+3 -3
View File
@@ -24,7 +24,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
@@ -36,7 +36,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -56,7 +56,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install --upgrade pip
@@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
pip install -e .
+1 -1
View File
@@ -14,7 +14,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 50000
+1 -1
View File
@@ -18,7 +18,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
+1 -1
View File
@@ -8,7 +8,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_XET_HIGH_PERFORMANCE: 1
HF_HUB_ENABLE_HF_TRANSFER: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
+1 -1
View File
@@ -47,7 +47,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.8"
- name: Install dependencies
run: |
+1 -1
View File
@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_xet \
hf_transfer \
setuptools==69.5.1 \
bitsandbytes \
torchao \
+1 -1
View File
@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
scipy \
tensorboard \
transformers \
hf_xet
hf_transfer
CMD ["/bin/bash"]
+3 -2
View File
@@ -38,12 +38,13 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
datasets \
hf-doc-builder \
huggingface-hub \
hf_xet \
hf_transfer \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
transformers \
hf_transfer
CMD ["/bin/bash"]
+1 -1
View File
@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_xet
hf_transfer
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
+1 -1
View File
@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_xet
hf_transfer
CMD ["/bin/bash"]
@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_xet
hf_transfer
CMD ["/bin/bash"]
@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_xet \
hf_transfer \
xformers
CMD ["/bin/bash"]
-16
View File
@@ -323,8 +323,6 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -349,8 +347,6 @@
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuanimage_transformer_2d
title: HunyuanImageTransformer2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
@@ -415,10 +411,6 @@
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -471,8 +463,6 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -529,8 +519,6 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5
title: Kandinsky 5
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
@@ -557,8 +545,6 @@
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/prx
title: PRX
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
@@ -632,8 +618,6 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hunyuanimage21
title: HunyuanImage2.1
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
@@ -1,32 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImage
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImage
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImage
[[autodoc]] AutoencoderKLHunyuanImage
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -1,32 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLHunyuanImageRefiner
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLHunyuanImageRefiner
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
```
## AutoencoderKLHunyuanImageRefiner
[[autodoc]] AutoencoderKLHunyuanImageRefiner
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# ChromaTransformer2DModel
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
## ChromaTransformer2DModel
@@ -1,30 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HunyuanImageTransformer2DModel
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
The model can be loaded with the following code snippet.
```python
from diffusers import HunyuanImageTransformer2DModel
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HunyuanImageTransformer2DModel
[[autodoc]] HunyuanImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -1,19 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaFiboTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
## BriaFiboTransformer2DModel
[[autodoc]] BriaFiboTransformer2DModel
-45
View File
@@ -1,45 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination - but not control. FIBO changes that.
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaPipeline
[[autodoc]] BriaPipeline
- all
- __call__
+6 -7
View File
@@ -19,21 +19,20 @@ specific language governing permissions and limitations under the License.
Chroma is a text to image generation model based on Flux.
Original model checkpoints for Chroma can be found here:
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
> [!TIP]
> Chroma can use all the same optimizations as Flux.
## Inference
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = [
@@ -64,10 +63,10 @@ Then run the following example
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
model_id = "lodestones/Chroma1-HD"
model_id = "lodestones/Chroma"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
@@ -1,152 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HunyuanImage2.1
HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
HunyuanImage-2.1 comes in the following variants:
| model type | model id |
|:----------:|:--------:|
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
```python
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
```
You can inspect the `guider` object:
```py
>>> pipe.guider
AdaptiveProjectedMixGuidance {
"_class_name": "AdaptiveProjectedMixGuidance",
"_diffusers_version": "0.36.0.dev0",
"adaptive_projected_guidance_momentum": -0.5,
"adaptive_projected_guidance_rescale": 10.0,
"adaptive_projected_guidance_scale": 10.0,
"adaptive_projected_guidance_start_step": 5,
"enabled": true,
"eta": 0.0,
"guidance_rescale": 0.0,
"guidance_scale": 3.5,
"start": 0.0,
"stop": 1.0,
"use_original_formulation": false
}
State:
step: None
num_inference_steps: None
timestep: None
count_prepared: 0
enabled: True
num_conditions: 2
momentum_buffer: None
is_apg_enabled: False
is_cfg_enabled: True
```
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained(
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
# Update the guider configuration
pipe.guider = pipe.guider.new(guidance_scale=5.0)
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
image = pipe(
prompt=prompt,
num_inference_steps=50,
height=2048,
width=2048,
).images[0]
image.save("image.png")
```
## HunyuanImage-2.1-Distilled
use `distilled_guidance_scale` with the guidance-distilled checkpoint,
```py
import torch
from diffusers import HunyuanImagePipeline
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = (
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
)
out = pipe(
prompt,
num_inference_steps=8,
distilled_guidance_scale=3.25,
height=2048,
width=2048,
generator=generator,
).images[0]
```
## HunyuanImagePipeline
[[autodoc]] HunyuanImagePipeline
- all
- __call__
## HunyuanImageRefinerPipeline
[[autodoc]] HunyuanImageRefinerPipeline
- all
- __call__
## HunyuanImagePipelineOutput
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
-149
View File
@@ -1,149 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
> [!TIP]
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
## Available Models
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
| model_id | Description | Use Cases |
|------------|-------------|-----------|
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
All models are available in 5-second and 10-second video generation versions.
## Kandinsky5T2VPipeline
[[autodoc]] Kandinsky5T2VPipeline
- all
- __call__
## Usage Examples
### Basic Text-to-Video Generation
```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video
# Load the pipeline
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### 10 second Models
**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
```python
pipe = Kandinsky5T2VPipeline.from_pretrained(
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Set attention backend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
) # <--- Compile with max-autotune-no-cudagraphs
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=241,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
output = pipe(
prompt="A beautiful sunset over mountains",
num_inference_steps=16, # <--- Model is distilled in 16 steps
guidance_scale=1.0, # <--- no CFG
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
## Citation
```bibtex
@misc{kandinsky2025,
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
year = 2025
}
```
-131
View File
@@ -1,131 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# PRX
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
## Available models
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
from diffusers.pipelines.prx import PRXPipeline
# Load pipeline - VAE and text encoder will be loaded from HuggingFace
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A front-facing portrait of a lion the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
### Manual Component Loading
Load components individually to customize the pipeline for instance to use quantized models.
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
"checkpoints/prx-512-t2i-sft",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
subfolder="vae",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
pipe = PRXPipeline(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae
)
pipe.to("cuda")
```
## Memory Optimization
For memory-constrained environments:
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PRXPipeline
[[autodoc]] PRXPipeline
- all
- __call__
## PRXPipelineOutput
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
@@ -21,7 +21,6 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
@@ -140,7 +139,6 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
File diff suppressed because it is too large Load Diff
-345
View File
@@ -1,345 +0,0 @@
#!/usr/bin/env python3
"""
Script to convert PRX checkpoint from original codebase to diffusers format.
"""
import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline
DEFAULT_RESOLUTION = 512
@dataclass(frozen=True)
class PRXBase:
context_in_dim: int = 2304
hidden_size: int = 1792
mlp_ratio: float = 3.5
num_heads: int = 28
depth: int = 16
axes_dim: Tuple[int, int] = (32, 32)
theta: int = 10_000
time_factor: float = 1000.0
time_max_period: int = 10_000
@dataclass(frozen=True)
class PRXFlux(PRXBase):
in_channels: int = 16
patch_size: int = 2
@dataclass(frozen=True)
class PRXDCAE(PRXBase):
in_channels: int = 32
patch_size: int = 1
def build_config(vae_type: str) -> Tuple[dict, int]:
if vae_type == "flux":
cfg = PRXFlux()
elif vae_type == "dc-ae":
cfg = PRXDCAE()
else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
config_dict = asdict(cfg)
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
return config_dict
def create_parameter_mapping(depth: int) -> dict:
"""Create mapping from old parameter names to new diffusers names."""
# Key mappings for structural changes
mapping = {}
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
for i in range(depth):
# QKV projections moved to attention module
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
# K norm for text tokens moved to attention module
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
# Attention output projection
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
return mapping
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
"""Convert old checkpoint parameters to new diffusers format."""
print("Converting checkpoint parameters...")
mapping = create_parameter_mapping(depth)
converted_state_dict = {}
for key, value in old_state_dict.items():
new_key = key
# Apply specific mappings if needed
if key in mapping:
new_key = mapping[key]
print(f" Mapped: {key} -> {new_key}")
converted_state_dict[new_key] = value
print(f"✓ Converted {len(converted_state_dict)} parameters")
return converted_state_dict
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
"""Create and load PRXTransformer2DModel from old checkpoint."""
print(f"Loading checkpoint from: {checkpoint_path}")
# Load old checkpoint
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Handle different checkpoint formats
if isinstance(old_checkpoint, dict):
if "model" in old_checkpoint:
state_dict = old_checkpoint["model"]
elif "state_dict" in old_checkpoint:
state_dict = old_checkpoint["state_dict"]
else:
state_dict = old_checkpoint
else:
state_dict = old_checkpoint
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
# Convert parameter names if needed
model_depth = int(config.get("depth", 16))
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
# Create transformer with config
print("Creating PRXTransformer2DModel...")
transformer = PRXTransformer2DModel(**config)
# Load state dict
print("Loading converted parameters...")
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
if missing_keys:
print(f"⚠ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠ Unexpected keys: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
print("✓ All parameters loaded successfully!")
return transformer
def create_scheduler_config(output_path: str, shift: float):
"""Create FlowMatchEulerDiscreteScheduler config."""
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
scheduler_path = os.path.join(output_path, "scheduler")
os.makedirs(scheduler_path, exist_ok=True)
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
json.dump(scheduler_config, f, indent=2)
print("✓ Created scheduler config")
def download_and_save_vae(vae_type: str, output_path: str):
"""Download and save VAE to local directory."""
from diffusers import AutoencoderDC, AutoencoderKL
vae_path = os.path.join(output_path, "vae")
os.makedirs(vae_path, exist_ok=True)
if vae_type == "flux":
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
else: # dc-ae
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
vae.save_pretrained(vae_path)
print(f"✓ Saved VAE to {vae_path}")
def download_and_save_text_encoder(output_path: str):
"""Download and save T5Gemma text encoder and tokenizer."""
from transformers import GemmaTokenizerFast
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
text_encoder_path = os.path.join(output_path, "text_encoder")
tokenizer_path = os.path.join(output_path, "tokenizer")
os.makedirs(text_encoder_path, exist_ok=True)
os.makedirs(tokenizer_path, exist_ok=True)
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
# Extract and save only the encoder
t5gemma_encoder = t5gemma_model.encoder
t5gemma_encoder.save_pretrained(text_encoder_path)
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
tokenizer.save_pretrained(tokenizer_path)
print(f"✓ Saved tokenizer to {tokenizer_path}")
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
"""Create model_index.json for the pipeline."""
if vae_type == "flux":
vae_class = "AutoencoderKL"
else: # dc-ae
vae_class = "AutoencoderDC"
model_index = {
"_class_name": "PRXPipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"default_sample_size": default_image_size,
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["prx", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "PRXTransformer2DModel"],
"vae": ["diffusers", vae_class],
}
model_index_path = os.path.join(output_path, "model_index.json")
with open(model_index_path, "w") as f:
json.dump(model_index, f, indent=2)
def main(args):
# Validate inputs
if not os.path.exists(args.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
config = build_config(args.vae_type)
# Create output directory
os.makedirs(args.output_path, exist_ok=True)
print(f"✓ Output directory: {args.output_path}")
# Create transformer from checkpoint
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
# Save transformer
transformer_path = os.path.join(args.output_path, "transformer")
os.makedirs(transformer_path, exist_ok=True)
# Save config
with open(os.path.join(transformer_path, "config.json"), "w") as f:
json.dump(config, f, indent=2)
# Save model weights as safetensors
state_dict = transformer.state_dict()
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
print(f"✓ Saved transformer to {transformer_path}")
# Create scheduler config
create_scheduler_config(args.output_path, args.shift)
download_and_save_vae(args.vae_type, args.output_path)
download_and_save_text_encoder(args.output_path)
# Create model_index.json
create_model_index(args.vae_type, args.resolution, args.output_path)
# Verify the pipeline can be loaded
try:
pipeline = PRXPipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
# Display model info
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
print(f"✓ Transformer parameters: {num_params:,}")
except Exception as e:
print(f"Pipeline verification failed: {e}")
return False
print("Conversion completed successfully!")
print(f"Converted pipeline saved to: {args.output_path}")
print(f"VAE type: {args.vae_type}")
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
)
parser.add_argument(
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
)
parser.add_argument(
"--vae_type",
type=str,
choices=["flux", "dc-ae"],
required=True,
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
)
parser.add_argument(
"--resolution",
type=int,
choices=[256, 512, 1024],
default=DEFAULT_RESOLUTION,
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
)
parser.add_argument(
"--shift",
type=float,
default=3.0,
help="Shift for the scheduler",
)
args = parser.parse_args()
try:
success = main(args)
if not success:
sys.exit(1)
except Exception as e:
print(f"Conversion failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
+2 -2
View File
@@ -122,7 +122,7 @@ _deps = [
"pytest",
"pytest-timeout",
"pytest-xdist",
"python>=3.9.0",
"python>=3.8.0",
"ruff==0.9.10",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
@@ -287,7 +287,7 @@ setup(
packages=find_packages("src"),
package_data={"diffusers": ["py.typed"]},
include_package_data=True,
python_requires=">=3.10.0",
python_requires=">=3.8.0",
install_requires=list(install_requires),
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
-22
View File
@@ -149,9 +149,7 @@ else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AdaptiveProjectedMixGuidance",
"AutoGuidance",
"BaseGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
@@ -186,8 +184,6 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -198,7 +194,6 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
@@ -221,7 +216,6 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
@@ -240,7 +234,6 @@ else:
"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"PRXTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
@@ -431,7 +424,6 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
@@ -469,8 +461,6 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanImagePipeline",
"HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
@@ -529,7 +519,6 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"PRXPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
@@ -858,9 +847,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .guiders import (
AdaptiveProjectedGuidance,
AdaptiveProjectedMixGuidance,
AutoGuidance,
BaseGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
@@ -891,8 +878,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -903,7 +888,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
@@ -926,7 +910,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
@@ -945,7 +928,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
@@ -1106,7 +1088,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
@@ -1144,8 +1125,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanImagePipeline,
HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
@@ -1204,7 +1183,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PRXPipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
+12 -12
View File
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, List
from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME
@@ -33,13 +33,13 @@ class PipelineCallback(ConfigMixin):
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
@property
def tensor_inputs(self) -> list[str]:
def tensor_inputs(self) -> List[str]:
raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
@@ -49,14 +49,14 @@ class MultiPipelineCallbacks:
provides a unified interface for calling all of them.
"""
def __init__(self, callbacks: list[PipelineCallback]):
def __init__(self, callbacks: List[PipelineCallback]):
self.callbacks = callbacks
@property
def tensor_inputs(self) -> list[str]:
def tensor_inputs(self) -> List[str]:
return [input for callback in self.callbacks for input in callback.tensor_inputs]
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
"""
Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
"""
@@ -76,7 +76,7 @@ class SDCFGCutoffCallback(PipelineCallback):
tensor_inputs = ["prompt_embeds"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -109,7 +109,7 @@ class SDXLCFGCutoffCallback(PipelineCallback):
"add_time_ids",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -152,7 +152,7 @@ class SDXLControlnetCFGCutoffCallback(PipelineCallback):
"image",
]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -195,7 +195,7 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
tensor_inputs = []
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
@@ -219,7 +219,7 @@ class SD3CFGCutoffCallback(PipelineCallback):
tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
+20 -18
View File
@@ -24,7 +24,7 @@ import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
@@ -94,10 +94,10 @@ class ConfigMixin:
Class attributes:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`list[str]`) -- A list of attributes that should not be saved in the config (should be
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
- **_deprecated_kwargs** (`list[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
"""
@@ -143,7 +143,7 @@ class ConfigMixin:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def save_config(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
[`~ConfigMixin.from_config`] class method.
@@ -155,7 +155,7 @@ class ConfigMixin:
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
@@ -189,13 +189,13 @@ class ConfigMixin:
@classmethod
def from_config(
cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs
) -> Self | tuple[Self, dict[str, Any]]:
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
r"""
Instantiate a Python class from a config dictionary.
Parameters:
config (`dict[str, Any]`):
config (`Dict[str, Any]`):
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
files of compatible classes.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
@@ -292,11 +292,11 @@ class ConfigMixin:
@validate_hf_hub_args
def load_config(
cls,
pretrained_model_name_or_path: str | os.PathLike,
pretrained_model_name_or_path: Union[str, os.PathLike],
return_unused_kwargs=False,
return_commit_hash=False,
**kwargs,
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r"""
Load a model or scheduler configuration.
@@ -315,7 +315,7 @@ class ConfigMixin:
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
@@ -352,7 +352,7 @@ class ConfigMixin:
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_entries: Optional[dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
@@ -563,7 +563,9 @@ class ConfigMixin:
return init_dict, unused_kwargs, hidden_config_dict
@classmethod
def _dict_from_json_file(cls, json_file: str | os.PathLike, dduf_entries: Optional[dict[str, DDUFEntry]] = None):
def _dict_from_json_file(
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
if dduf_entries:
text = dduf_entries[json_file].read_text()
else:
@@ -575,12 +577,12 @@ class ConfigMixin:
return f"{self.__class__.__name__} {self.to_json_string()}"
@property
def config(self) -> dict[str, Any]:
def config(self) -> Dict[str, Any]:
"""
Returns the config of the class as a frozen dictionary
Returns:
`dict[str, Any]`: Config of the class.
`Dict[str, Any]`: Config of the class.
"""
return self._internal_dict
@@ -623,7 +625,7 @@ class ConfigMixin:
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: str | os.PathLike):
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save the configuration instance's parameters to a JSON file.
@@ -635,7 +637,7 @@ class ConfigMixin:
writer.write(self.to_json_string())
@classmethod
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: dict[str, DDUFEntry]):
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
# paths inside a DDUF file must always be "/"
config_file = (
cls.config_name
@@ -754,7 +756,7 @@ class LegacyConfigMixin(ConfigMixin):
"""
@classmethod
def from_config(cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs):
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
# To prevent dependency import problem.
from .models.model_loading_utils import _fetch_remapped_cls_from_config
+1 -1
View File
@@ -29,7 +29,7 @@ deps = {
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.9.0",
"python": "python>=3.8.0",
"ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+13 -3
View File
@@ -14,18 +14,28 @@
from typing import Union
from ..utils import is_torch_available, logging
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
]
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -67,9 +65,8 @@ class AdaptiveProjectedGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
@@ -79,14 +76,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -150,44 +152,6 @@ class MomentumBuffer:
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def normalized_guidance(
pred_cond: torch.Tensor,
@@ -1,284 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedMixGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
used, and momentum buffer is updated).
enabled (`bool`, defaults to `True`):
Whether this guidance is enabled.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 3.5,
guidance_rescale: float = 0.0,
adaptive_projected_guidance_scale: float = 10.0,
adaptive_projected_guidance_momentum: float = -0.5,
adaptive_projected_guidance_rescale: float = 10.0,
eta: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
adaptive_projected_guidance_start_step: int = 5,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# no guidance
if not self._is_cfg_enabled():
pred = pred_cond
# CFG + update momentum buffer
elif not self._is_apg_enabled():
if self.momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
# CFG + update momentum buffer
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
# APG
elif self._is_apg_enabled():
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.adaptive_projected_guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled() or self._is_cfg_enabled():
num_conditions += 1
return num_conditions
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
if not self._is_cfg_enabled():
return False
is_within_range = False
if self._step is not None:
is_within_range = self._step > self.adaptive_projected_guidance_start_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
else:
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
return is_within_range and not is_close
def get_state(self):
state = super().get_state()
state["momentum_buffer"] = self.momentum_buffer
state["is_apg_enabled"] = self._is_apg_enabled()
state["is_cfg_enabled"] = self._is_cfg_enabled()
return state
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def __repr__(self) -> str:
"""
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
"""
if isinstance(self.running_average, torch.Tensor):
shape = tuple(self.running_average.shape)
# Calculate statistics
with torch.no_grad():
stats = {
"mean": self.running_average.mean().item(),
"std": self.running_average.std().item(),
"min": self.running_average.min().item(),
"max": self.running_average.max().item(),
}
# Get a slice (max 3 elements per dimension)
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
sliced_data = self.running_average[slice_indices]
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
slice_str = str(sliced_data.detach().float().cpu().numpy())
if len(slice_str) > 200: # Truncate if too long
slice_str = slice_str[:200] + "..."
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
return (
f"MomentumBuffer(\n"
f" momentum={self.momentum},\n"
f" shape={shape},\n"
f" stats=[{stats_str}],\n"
f" slice={slice_str}\n"
f")"
)
else:
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
def update_momentum_buffer(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
momentum_buffer: Optional[MomentumBuffer] = None,
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
if momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
diff = momentum_buffer.running_average
else:
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
+14 -12
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -38,10 +36,10 @@ class AutoGuidance(BaseGuidance):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
auto_guidance_layers (`int` or `list[int]`, *optional*):
auto_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided.
auto_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
dropout (`float`, *optional*):
@@ -67,16 +65,15 @@ class AutoGuidance(BaseGuidance):
def __init__(
self,
guidance_scale: float = 7.5,
auto_guidance_layers: Optional[int | list[int]] = None,
auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
dropout: Optional[float] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
@@ -135,11 +132,16 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -29,50 +27,43 @@ if TYPE_CHECKING:
class ClassifierFreeGuidance(BaseGuidance):
"""
Implements Classifier-Free Guidance (CFG) for diffusion models.
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
Reference: https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
proposes scaling and shifting the conditional distribution based on the difference between conditional and
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
unconditional data, then combining predictions during inference. This allows trading off between quality (high
guidance) and diversity (low guidance).
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
**Two CFG Formulations:**
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
1. **Original formulation** (from paper):
```
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
```
Moves conditional predictions further from unconditional ones.
2. **Diffusers-native formulation** (default, from Imagen paper):
```
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
```
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
quality", "watermarks"). Equivalent in theory but more intuitive.
Use `use_original_formulation=True` to switch to the original formulation.
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
may reduce quality. Typical range: 1.0-20.0.
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
to 1.0 (full rescaling).
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
diffusers-native formulation from the Imagen paper.
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
steps.
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
steps.
enabled (`bool`, defaults to `True`):
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@@ -85,19 +76,23 @@ class ClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -70,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
# YiYi Notes: add default behavior for self._enabled == False
if not self._enabled:
pred = pred_cond
elif self._step < self.zero_init_steps:
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -39,7 +37,7 @@ else:
build_laplacian_pyramid_func = None
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
(Algorithm 2).
@@ -60,7 +58,7 @@ def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -
return v0_parallel, v0_orthogonal
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
"""
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
(Algorithm 2).
@@ -101,19 +99,19 @@ class FrequencyDecoupledGuidance(BaseGuidance):
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
descending order).
guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
`guidance_scales`.
parallel_weights (`float` or `list[float]`, *optional*):
parallel_weights (`float` or `List[float]`, *optional*):
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
@@ -122,10 +120,10 @@ class FrequencyDecoupledGuidance(BaseGuidance):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float` or `list[float]`, defaults to `0.0`):
start (`float` or `List[float]`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
should be the same length as `guidance_scales`.
stop (`float` or `list[float]`, defaults to `1.0`):
stop (`float` or `List[float]`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
should be the same length as `guidance_scales`.
guidance_rescale_space (`str`, defaults to `"data"`):
@@ -143,15 +141,14 @@ class FrequencyDecoupledGuidance(BaseGuidance):
@register_to_config
def __init__(
self,
guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
guidance_rescale: float | list[float] | tuple[float] = 0.0,
parallel_weights: Optional[float | list[float] | tuple[float]] = None,
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
use_original_formulation: bool = False,
start: float | list[float] | tuple[float] = 0.0,
stop: float | list[float] | tuple[float] = 1.0,
start: Union[float, List[float], Tuple[float]] = 0.0,
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
enabled: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
@@ -163,7 +160,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop, enabled)
super().__init__(min_start, max_stop)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
@@ -220,11 +217,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})"
)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
+57 -92
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -42,18 +40,14 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
logger.warning(
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
def __init__(self, start: float = 0.0, stop: float = 1.0):
self._start = start
self._stop = stop
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: dict[str, str | tuple[str, str]] = None
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
if not (0.0 <= start < 1.0):
@@ -66,31 +60,6 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def new(self, **kwargs):
"""
Creates a copy of this guider instance, optionally with modified configuration parameters.
Args:
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
returns an exact copy with the same configuration.
Returns:
A new guider instance with the same (or updated) configuration.
Example:
```python
# Create a CFG guider
guider = ClassifierFreeGuidance(guidance_scale=3.5)
# Create an exact copy
same_guider = guider.new()
# Create a copy with different start step, keeping other config the same
new_guider = guider.new(guidance_scale=5)
```
"""
return self.__class__.from_config(self.config, **kwargs)
def disable(self):
self._enabled = False
@@ -103,52 +72,42 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep
self._count_prepared = 0
def get_state(self) -> dict[str, Any]:
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
"""
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
the __repr__ method. Returns:
`Dict[str, Any]`: A dictionary containing the current state variables including:
- step: Current inference step
- num_inference_steps: Total number of inference steps
- timestep: Current timestep tensor
- count_prepared: Number of times prepare_models has been called
- enabled: Whether the guidance is enabled
- num_conditions: Number of conditions
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with a
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
conditional data identifier and the second element must be the unconditional data identifier or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
state = {
"step": self._step,
"num_inference_steps": self._num_inference_steps,
"timestep": self._timestep,
"count_prepared": self._count_prepared,
"enabled": self._enabled,
"num_conditions": self.num_conditions,
}
return state
def __repr__(self) -> str:
"""
Returns a string representation of the guidance object including both config and current state.
"""
# Get ConfigMixin's __repr__
str_repr = super().__repr__()
# Get current state
state = self.get_state()
# Format each state variable on its own line with indentation
state_lines = []
for k, v in state.items():
# Convert value to string and handle multi-line values
v_str = str(v)
if "\n" in v_str:
# For multi-line values (like MomentumBuffer), indent subsequent lines
v_lines = v_str.split("\n")
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
state_lines.append(f" {k}: {v_str}")
state_str = "\n".join(state_lines)
return f"{str_repr}\nState:\n{state_str}"
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = (
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
@@ -165,10 +124,10 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
pass
def prepare_inputs(self, data: "BlockState") -> list["BlockState"]:
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def __call__(self, data: list["BlockState"]) -> Any:
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
if len(data) != self.num_conditions:
@@ -196,7 +155,8 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod
def _prepare_batch(
cls,
data: dict[str, tuple[torch.Tensor, torch.Tensor]],
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
@@ -205,7 +165,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`dict[str, Union[str, tuple[str, str]]]`):
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
@@ -222,16 +182,21 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError(
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
)
data_batch = {}
for key, value in data.items():
for key, value in input_fields.items():
try:
if isinstance(value, torch.Tensor):
data_batch[key] = value
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = value[tuple_index]
data_batch[key] = getattr(data, value[tuple_index])
else:
raise ValueError(f"Invalid value type: {type(value)}")
except ValueError:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@@ -240,7 +205,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[str | os.PathLike] = None,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
@@ -267,7 +232,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
@@ -297,7 +262,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a guider configuration object to a directory so that it can be reloaded using the
[`~BaseGuidance.from_pretrained`] class method.
@@ -309,7 +274,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -60,10 +58,10 @@ class PerturbedAttentionGuidance(BaseGuidance):
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
perturbed_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
perturbed_guidance_layers (`int` or `list[int]`, *optional*):
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
If not provided, `perturbed_guidance_config` must be provided.
perturbed_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
@@ -94,15 +92,14 @@ class PerturbedAttentionGuidance(BaseGuidance):
perturbed_guidance_scale: float = 2.8,
perturbed_guidance_start: float = 0.01,
perturbed_guidance_stop: float = 0.2,
perturbed_guidance_layers: Optional[int | list[int]] = None,
perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
@@ -171,7 +168,12 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -184,8 +186,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
+14 -12
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -66,11 +64,11 @@ class SkipLayerGuidance(BaseGuidance):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `list[int]`, *optional*):
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
@@ -96,15 +94,14 @@ class SkipLayerGuidance(BaseGuidance):
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[int | list[int]] = None,
skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
@@ -167,7 +164,12 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -180,8 +182,8 @@ class SkipLayerGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -56,11 +54,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `list[int]`, *optional*):
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
Diffusion 3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `list[SmoothedEnergyGuidanceConfig]`, *optional*):
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
@@ -88,15 +86,14 @@ class SmoothedEnergyGuidance(BaseGuidance):
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[int | list[int]] = None,
seg_guidance_config: SmoothedEnergyGuidanceConfig | list[SmoothedEnergyGuidanceConfig] = None,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
@@ -156,7 +153,12 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -169,8 +171,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -60,19 +58,23 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
+2 -32
View File
@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Type
from typing import Any, Callable, Dict, Type
@dataclass
@@ -28,7 +28,7 @@ class TransformerBlockMetadata:
return_encoder_hidden_states_index: int = None
_cls: Type = None
_cached_parameter_indices: dict[str, int] = None
_cached_parameter_indices: Dict[str, int] = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
@@ -108,7 +108,6 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
@@ -150,14 +149,6 @@ def _register_attention_processors_metadata():
),
)
# HunyuanImageAttnProcessor
AttentionProcessorRegistry.register(
model_class=HunyuanImageAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -171,10 +162,6 @@ def _register_transformer_blocks_metadata():
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_hunyuanimage import (
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -296,22 +283,6 @@ def _register_transformer_blocks_metadata():
),
)
# HunyuanImage2.1
TransformerBlockRegistry.register(
model_class=HunyuanImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanImageSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -337,5 +308,4 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
+6 -6
View File
@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Type
from typing import Dict, List, Type, Union
import torch
@@ -42,7 +42,7 @@ _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
cached_parameter_indices: dict[str, int] = None
cached_parameter_indices: Dict[str, int] = None
_cls: Type = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
@@ -78,7 +78,7 @@ class ModuleForwardMetadata:
def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
plan: dict[str, ContextParallelModelPlan],
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
@@ -107,7 +107,7 @@ def apply_context_parallel(
registry.register_hook(hook, hook_name)
def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
@@ -272,13 +272,13 @@ class EquipartitionSharder:
return tensor
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
return _find_submodule_by_name(model, name)
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+22 -22
View File
@@ -14,7 +14,7 @@
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional, Tuple
import torch
@@ -60,7 +60,7 @@ class FasterCacheConfig:
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
states again.
spatial_attention_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 681)`):
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
The timestep range within which the spatial attention computation can be skipped without a significant loss
in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
@@ -68,17 +68,17 @@ class FasterCacheConfig:
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
process.
temporal_attention_timestep_skip_range (`tuple[float, float]`, *optional*, defaults to `None`):
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
The timestep range within which the temporal attention computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
timestep 0).
low_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(99, 901)`):
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
The timestep range within which the low frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
high_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(-1, 301)`):
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
The timestep range within which the high frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
@@ -92,15 +92,15 @@ class FasterCacheConfig:
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
computing the new unconditional branch states again.
unconditional_batch_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 641)`):
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
The timestep range within which the unconditional branch computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound.
spatial_attention_block_identifiers (`tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
temporal_attention_block_identifiers (`tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
@@ -123,7 +123,7 @@ class FasterCacheConfig:
is_guidance_distilled (`bool`, defaults to `False`):
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
_unconditional_conditional_input_kwargs_identifiers (`list[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
@@ -135,12 +135,12 @@ class FasterCacheConfig:
spatial_attention_block_skip_range: int = 2
temporal_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
temporal_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
low_frequency_weight_update_timestep_range: tuple[int, int] = (99, 901)
high_frequency_weight_update_timestep_range: tuple[int, int] = (-1, 301)
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
# 1 and 2 as mentioned in Equation 11 of the paper
alpha_low_frequency: float = 1.1
@@ -148,10 +148,10 @@ class FasterCacheConfig:
# n as described in CFG-Cache explanation in the paper - dependent on the model
unconditional_batch_skip_range: int = 5
unconditional_batch_timestep_skip_range: tuple[int, int] = (-1, 641)
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
attention_weight_callback: Callable[[torch.nn.Module], float] = None
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
@@ -162,7 +162,7 @@ class FasterCacheConfig:
current_timestep_callback: Callable[[], int] = None
_unconditional_conditional_input_kwargs_identifiers: list[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
def __repr__(self) -> str:
return (
@@ -209,7 +209,7 @@ class FasterCacheBlockState:
def __init__(self) -> None:
self.iteration: int = 0
self.batch_size: int = None
self.cache: tuple[torch.Tensor, torch.Tensor] = None
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
def reset(self):
self.iteration = 0
@@ -223,10 +223,10 @@ class FasterCacheDenoiserHook(ModelHook):
def __init__(
self,
unconditional_batch_skip_range: int,
unconditional_batch_timestep_skip_range: tuple[int, int],
unconditional_batch_timestep_skip_range: Tuple[int, int],
tensor_format: str,
is_guidance_distilled: bool,
uncond_cond_input_kwargs_identifiers: list[str],
uncond_cond_input_kwargs_identifiers: List[str],
current_timestep_callback: Callable[[], int],
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
@@ -252,7 +252,7 @@ class FasterCacheDenoiserHook(ModelHook):
return module
@staticmethod
def _get_cond_input(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
# followed by conditional inputs.
_, cond = input.chunk(2, dim=0)
@@ -371,7 +371,7 @@ class FasterCacheBlockHook(ModelHook):
def __init__(
self,
block_skip_range: int,
timestep_skip_range: tuple[int, int],
timestep_skip_range: Tuple[int, int],
is_guidance_distilled: bool,
weight_callback: Callable[[torch.nn.Module], float],
current_timestep_callback: Callable[[], int],
+3 -2
View File
@@ -13,6 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
@@ -52,9 +53,9 @@ class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] = None
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
+13 -13
View File
@@ -17,7 +17,7 @@ import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
import torch
@@ -58,21 +58,21 @@ class GroupOffloadingConfig:
low_cpu_mem_usage: bool
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[torch.cuda.Stream | torch.Stream] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
class ModuleGroup:
def __init__(
self,
modules: list[torch.nn.Module],
modules: List[torch.nn.Module],
offload_device: torch.device,
onload_device: torch.device,
offload_leader: torch.nn.Module,
onload_leader: Optional[torch.nn.Module] = None,
parameters: Optional[list[torch.nn.Parameter]] = None,
buffers: Optional[list[torch.Tensor]] = None,
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: torch.cuda.Stream | torch.Stream | None = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
@@ -340,7 +340,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(self):
self.execution_order: list[tuple[str, torch.nn.Module]] = []
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
@@ -444,9 +444,9 @@ class LayerExecutionTrackerHook(ModelHook):
def apply_group_offloading(
module: torch.nn.Module,
onload_device: str | torch.device,
offload_device: str | torch.device = torch.device("cpu"),
offload_type: str | GroupOffloadingType = "block_level",
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -787,7 +787,7 @@ def _apply_lazy_group_offloading_hook(
def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> list[torch.nn.Parameter]:
) -> List[torch.nn.Parameter]:
parameters = []
for name, parameter in module.named_parameters():
has_parent_with_group_offloading = False
@@ -805,7 +805,7 @@ def _gather_parameters_with_no_group_offloading_parent(
def _gather_buffers_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> list[torch.Tensor]:
) -> List[torch.Tensor]:
buffers = []
for name, buffer in module.named_buffers():
has_parent_with_group_offloading = False
@@ -821,7 +821,7 @@ def _gather_buffers_with_no_group_offloading_parent(
return buffers
def _find_parent_module_in_module_dict(name: str, module_dict: dict[str, torch.nn.Module]) -> str:
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
+6 -6
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import functools
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple
import torch
@@ -86,19 +86,19 @@ class ModelHook:
"""
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> tuple[tuple[Any], dict[str, Any]]:
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`tuple[Any]`):
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`dict[Str, Any]`):
kwargs (`Dict[Str, Any]`):
The keyword arguments passed to the module.
Returns:
`tuple[tuple[Any], dict[Str, Any]]`:
`Tuple[Tuple[Any], Dict[Str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs
@@ -168,7 +168,7 @@ class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()
self.hooks: dict[str, ModelHook] = {}
self.hooks: Dict[str, ModelHook] = {}
self._module_ref = module_ref
self._hook_order = []
+3 -3
View File
@@ -14,7 +14,7 @@
import math
from dataclasses import asdict, dataclass
from typing import Callable, Optional
from typing import Callable, List, Optional
import torch
@@ -43,7 +43,7 @@ class LayerSkipConfig:
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`list[int]`):
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
@@ -63,7 +63,7 @@ class LayerSkipConfig:
skipped layers are fully retained, which is equivalent to not skipping any layers.
"""
indices: list[int]
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
+7 -7
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import re
from typing import Optional, Type
from typing import Optional, Tuple, Type, Union
import torch
@@ -102,8 +102,8 @@ def apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: str | tuple[str, ...] = "auto",
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
@@ -137,12 +137,12 @@ def apply_layerwise_casting(
The dtype to cast the module to before/after the forward pass for storage.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass for computation.
skip_modules_pattern (`tuple[str, ...]`, defaults to `"auto"`):
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
instead of its internal submodules.
skip_modules_classes (`tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
A list of module classes to skip during the layerwise casting process.
non_blocking (`bool`, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
@@ -169,8 +169,8 @@ def _apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Optional[tuple[str, ...]] = None,
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
_prefix: str = "",
) -> None:
@@ -14,7 +14,7 @@
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Tuple, Union
import torch
@@ -54,20 +54,20 @@ class PyramidAttentionBroadcastConfig:
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be reused) before computing the new attention states again.
spatial_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
temporal_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the temporal attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
cross_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`tuple[str, ...]`):
spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`tuple[str, ...]`):
temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`tuple[str, ...]`):
cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
@@ -75,13 +75,13 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: tuple[int, int] = (100, 800)
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
@@ -141,7 +141,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
_is_stateful = True
def __init__(
self, timestep_skip_range: tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
) -> None:
super().__init__()
@@ -288,8 +288,8 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
def _apply_pyramid_attention_broadcast_hook(
module: Attention | MochiAttention,
timestep_skip_range: tuple[int, int],
module: Union[Attention, MochiAttention],
timestep_skip_range: Tuple[int, int],
block_skip_range: int,
current_timestep_callback: Callable[[], int],
):
@@ -299,7 +299,7 @@ def _apply_pyramid_attention_broadcast_hook(
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
timestep_skip_range (`tuple[int, int]`):
timestep_skip_range (`Tuple[int, int]`):
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
skipped if the current timestep is within the specified range.
block_skip_range (`int`):
@@ -14,7 +14,7 @@
import math
from dataclasses import asdict, dataclass
from typing import Optional
from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -35,21 +35,21 @@ class SmoothedEnergyGuidanceConfig:
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`list[int]`):
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
provide the correct fqn.
_query_proj_identifiers (`list[str]`, defaults to `None`):
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
`None`, `to_q` is used by default.
"""
indices: list[int]
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: list[str] = None
_query_proj_identifiers: List[str] = None
def to_dict(self):
return asdict(self)
+2 -2
View File
@@ -21,8 +21,8 @@ def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
module_list_with_transformer_blocks = []
for name, submodule in module.named_modules():
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
is_ModuleList = isinstance(submodule, torch.nn.ModuleList)
if name_endswith_identifier and is_ModuleList:
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
if name_endswith_identifier and is_modulelist:
module_list_with_transformer_blocks.append((name, submodule))
return module_list_with_transformer_blocks
+48 -43
View File
@@ -14,7 +14,7 @@
import math
import warnings
from typing import Optional
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
@@ -26,9 +26,14 @@ from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = (
PIL.Image.Image | np.ndarray | torch.Tensor | list[PIL.Image.Image] | list[np.ndarray] | list[torch.Tensor]
)
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.Tensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.Tensor],
]
PipelineDepthInput = PipelineImageInput
@@ -63,7 +68,7 @@ def is_valid_image_imagelist(images):
- A list of valid images.
Args:
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, list]`):
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
images.
@@ -126,7 +131,7 @@ class VaeImageProcessor(ConfigMixin):
)
@staticmethod
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a numpy image or a batch of images to a PIL image.
@@ -135,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
The image array to convert to PIL format.
Returns:
`list[PIL.Image.Image]`:
`List[PIL.Image.Image]`:
A list of PIL images.
"""
if images.ndim == 3:
@@ -150,12 +155,12 @@ class VaeImageProcessor(ConfigMixin):
return pil_images
@staticmethod
def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`PIL.Image.Image` or `list[PIL.Image.Image]`):
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The PIL image or list of images to convert to NumPy format.
Returns:
@@ -205,7 +210,7 @@ class VaeImageProcessor(ConfigMixin):
return images
@staticmethod
def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Normalize an image array to [-1,1].
@@ -220,7 +225,7 @@ class VaeImageProcessor(ConfigMixin):
return 2.0 * images - 1.0
@staticmethod
def denormalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Denormalize an image array to [0,1].
@@ -462,11 +467,11 @@ class VaeImageProcessor(ConfigMixin):
def resize(
self,
image: PIL.Image.Image | np.ndarray | torch.Tensor,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: int,
width: int,
resize_mode: str = "default", # "default", "fill", "crop"
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Resize image.
@@ -539,7 +544,7 @@ class VaeImageProcessor(ConfigMixin):
return image
def _denormalize_conditionally(
self, images: torch.Tensor, do_denormalize: Optional[list[bool]] = None
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
) -> torch.Tensor:
r"""
Denormalize a batch of images based on a condition list.
@@ -547,7 +552,7 @@ class VaeImageProcessor(ConfigMixin):
Args:
images (`torch.Tensor`):
The input image tensor.
do_denormalize (`Optional[list[bool]`, *optional*, defaults to `None`):
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
value of `do_normalize` in the `VaeImageProcessor` config.
"""
@@ -560,10 +565,10 @@ class VaeImageProcessor(ConfigMixin):
def get_default_height_width(
self,
image: PIL.Image.Image | np.ndarray | torch.Tensor,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> tuple[int, int]:
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
@@ -578,7 +583,7 @@ class VaeImageProcessor(ConfigMixin):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`tuple[int, int]`:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor`.
"""
@@ -611,7 +616,7 @@ class VaeImageProcessor(ConfigMixin):
height: Optional[int] = None,
width: Optional[int] = None,
resize_mode: str = "default", # "default", "fill", "crop"
crops_coords: Optional[tuple[int, int, int, int]] = None,
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> torch.Tensor:
"""
Preprocess the image input.
@@ -633,7 +638,7 @@ class VaeImageProcessor(ConfigMixin):
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
supported for PIL image input.
crops_coords (`list[tuple[int, int, int, int]]`, *optional*, defaults to `None`):
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
Returns:
@@ -740,8 +745,8 @@ class VaeImageProcessor(ConfigMixin):
self,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[list[bool]] = None,
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.
@@ -750,7 +755,7 @@ class VaeImageProcessor(ConfigMixin):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
`VaeImageProcessor` config.
@@ -791,7 +796,7 @@ class VaeImageProcessor(ConfigMixin):
mask: PIL.Image.Image,
init_image: PIL.Image.Image,
image: PIL.Image.Image,
crop_coords: Optional[tuple[int, int, int, int]] = None,
crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
r"""
Applies an overlay of the mask and the inpainted image on the original image.
@@ -803,7 +808,7 @@ class VaeImageProcessor(ConfigMixin):
The original image to which the overlay is applied.
image (`PIL.Image.Image`):
The image to overlay onto the original.
crop_coords (`tuple[int, int, int, int]`, *optional*):
crop_coords (`Tuple[int, int, int, int]`, *optional*):
Coordinates to crop the image. If provided, the image will be cropped accordingly.
Returns:
@@ -886,7 +891,7 @@ class InpaintProcessor(ConfigMixin):
height: int = None,
width: int = None,
padding_mask_crop: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Preprocess the image and mask.
"""
@@ -941,8 +946,8 @@ class InpaintProcessor(ConfigMixin):
output_type: str = "pil",
original_image: Optional[PIL.Image.Image] = None,
original_mask: Optional[PIL.Image.Image] = None,
crops_coords: Optional[tuple[int, int, int, int]] = None,
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
"""
Postprocess the image, optionally apply mask overlay
"""
@@ -993,7 +998,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
super().__init__()
@staticmethod
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a NumPy image or a batch of images to a list of PIL images.
@@ -1002,7 +1007,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
The input NumPy array of images, which can be a single image or a batch.
Returns:
`list[PIL.Image.Image]`:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy array.
"""
if images.ndim == 3:
@@ -1017,12 +1022,12 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return pil_images
@staticmethod
def depth_pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`Union[list[PIL.Image.Image], PIL.Image.Image]`):
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
The input image or list of images to be converted.
Returns:
@@ -1037,7 +1042,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return images
@staticmethod
def rgblike_to_depthmap(image: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Convert an RGB-like depth image to a depth map.
@@ -1051,7 +1056,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
def numpy_to_depth(self, images: np.ndarray) -> list[PIL.Image.Image]:
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
r"""
Convert a NumPy depth image or a batch of images to a list of PIL images.
@@ -1060,7 +1065,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
The input NumPy array of depth images, which can be a single image or a batch.
Returns:
`list[PIL.Image.Image]`:
`List[PIL.Image.Image]`:
A list of PIL images converted from the input NumPy depth images.
"""
if images.ndim == 3:
@@ -1083,8 +1088,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
self,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[list[bool]] = None,
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.
@@ -1093,7 +1098,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
`VaeImageProcessor` config.
@@ -1131,8 +1136,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def preprocess(
self,
rgb: torch.Tensor | PIL.Image.Image | np.ndarray,
depth: torch.Tensor | PIL.Image.Image | np.ndarray,
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
target_res: Optional[int] = None,
@@ -1153,7 +1158,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
Target resolution for resizing the images. If specified, overrides height and width.
Returns:
`tuple[torch.Tensor, torch.Tensor]`:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing the processed RGB and depth images as PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
@@ -1391,7 +1396,7 @@ class PixArtImageProcessor(VaeImageProcessor):
)
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> tuple[int, int]:
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
r"""
Returns the binned height and width based on the aspect ratio.
@@ -1401,7 +1406,7 @@ class PixArtImageProcessor(VaeImageProcessor):
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
Returns:
`tuple[int, int]`: The closest binned height and width.
`Tuple[int, int]`: The closest binned height and width.
"""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+31 -31
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -57,15 +57,15 @@ class IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
subfolder: str | list[str],
weight_name: str | list[str],
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -74,10 +74,10 @@ class IPAdapterMixin:
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `list[str]`):
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `list[str]`):
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
@@ -94,7 +94,7 @@ class IPAdapterMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -358,14 +358,14 @@ class ModularIPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
subfolder: str | list[str],
weight_name: str | list[str],
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -374,10 +374,10 @@ class ModularIPAdapterMixin:
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `list[str]`):
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `list[str]`):
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
@@ -387,7 +387,7 @@ class ModularIPAdapterMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -608,9 +608,9 @@ class FluxIPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
weight_name: str | list[str],
subfolder: Optional[str | list[str]] = "",
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
@@ -618,7 +618,7 @@ class FluxIPAdapterMixin:
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -627,10 +627,10 @@ class FluxIPAdapterMixin:
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `list[str]`):
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `list[str]`):
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
@@ -647,7 +647,7 @@ class FluxIPAdapterMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -797,13 +797,13 @@ class FluxIPAdapterMixin:
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: float | list[float] | list[list[float]]):
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `list[float]`
length match the number of blocks, it is repeated for each IP adapter. `list[list[float]]` must match the
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
number of IP adapters and each must match the number of blocks.
Example:
@@ -823,18 +823,18 @@ class FluxIPAdapterMixin:
```
"""
scale_type = int | float
scale_type = Union[int, float]
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
num_layers = self.transformer.config.num_layers
# Single value for all layers of all IP-Adapters
if isinstance(scale, scale_type):
scale = [scale for _ in range(num_ip_adapters)]
# list of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, list[scale_type]) and num_ip_adapters == 1:
# List of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
scale = [scale]
# Invalid scale type
elif not _is_valid_type(scale, list[scale_type | list[scale_type]]):
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
if len(scale) != num_ip_adapters:
@@ -918,7 +918,7 @@ class SD3IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str = "ip-adapter.safetensors",
subfolder: Optional[str] = None,
image_encoder_folder: Optional[str] = "image_encoder",
@@ -953,7 +953,7 @@ class SD3IPAdapterMixin:
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
+27 -27
View File
@@ -17,7 +17,7 @@ import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
@@ -77,7 +77,7 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adap
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
"""
merge_kwargs = {"safe_merge": safe_fusing}
@@ -116,20 +116,20 @@ def unfuse_text_encoder_lora(text_encoder):
def set_adapters_for_text_encoder(
adapter_names: list[str] | str,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: Optional[float | list[float] | list[None]] = None,
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
"""
Sets the adapter layers for the text encoder.
Args:
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`list[float]`, *optional*):
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if text_encoder is None:
@@ -535,10 +535,10 @@ class LoraBaseMixin:
def fuse_lora(
self,
components: list[str] = [],
components: List[str] = [],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -547,12 +547,12 @@ class LoraBaseMixin:
> [!WARNING] > This is an experimental API.
Args:
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`list[str]`, *optional*):
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
@@ -619,7 +619,7 @@ class LoraBaseMixin:
self._merged_adapters = self._merged_adapters | merged_adapter_names
def unfuse_lora(self, components: list[str] = [], **kwargs):
def unfuse_lora(self, components: List[str] = [], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -627,7 +627,7 @@ class LoraBaseMixin:
> [!WARNING] > This is an experimental API.
Args:
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
@@ -674,16 +674,16 @@ class LoraBaseMixin:
def set_adapters(
self,
adapter_names: list[str] | str,
adapter_weights: Optional[float | Dict | list[float] | list[Dict]] = None,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
"""
Set the currently active adapters for use in the pipeline.
Args:
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[list[float], float]`, *optional*):
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -835,12 +835,12 @@ class LoraBaseMixin:
elif issubclass(model.__class__, PreTrainedModel):
enable_lora_for_text_encoder(model)
def delete_adapters(self, adapter_names: list[str] | str):
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the pipeline.
Args:
adapter_names (`Union[list[str], str]`):
adapter_names (`Union[List[str], str]`):
The names of the adapters to delete.
Example:
@@ -873,7 +873,7 @@ class LoraBaseMixin:
for adapter_name in adapter_names:
delete_adapter_layers(model, adapter_name)
def get_active_adapters(self) -> list[str]:
def get_active_adapters(self) -> List[str]:
"""
Gets the list of the current active adapters.
@@ -906,7 +906,7 @@ class LoraBaseMixin:
return active_adapters
def get_list_adapters(self) -> dict[str, list[str]]:
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
Gets the current list of all available adapters in the pipeline.
"""
@@ -928,7 +928,7 @@ class LoraBaseMixin:
return set_adapters
def set_lora_device(self, adapter_names: list[str], device: torch.device | str | int) -> None:
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
@@ -955,8 +955,8 @@ class LoraBaseMixin:
```
Args:
adapter_names (`list[str]`):
list of adapters to send device to.
adapter_names (`List[str]`):
List of adapters to send device to.
device (`Union[torch.device, str, int]`):
Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
@@ -1007,7 +1007,7 @@ class LoraBaseMixin:
@staticmethod
def write_lora_layers(
state_dict: dict[str, torch.Tensor],
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
@@ -1059,9 +1059,9 @@ class LoraBaseMixin:
@classmethod
def _save_lora_weights(
cls,
save_directory: str | os.PathLike,
lora_layers: dict[str, dict[str, torch.nn.Module | torch.Tensor]],
lora_metadata: dict[str, Optional[dict]],
save_directory: Union[str, os.PathLike],
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
lora_metadata: Dict[str, Optional[dict]],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
+7 -26
View File
@@ -13,6 +13,7 @@
# limitations under the License.
import re
from typing import List
import torch
@@ -1020,7 +1021,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
return new_state_dict
def _custom_replace(key: str, substrings: list[str]) -> str:
def _custom_replace(key: str, substrings: List[str]) -> str:
# Replaces the "."s with "_"s upto the `substrings`.
# Example:
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
@@ -1976,34 +1977,14 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
if any("head.head" in k for k in original_state_dict):
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
f"head.head.{lora_up_key}.weight"
)
if any("head.head" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
# an identity.
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
if f"head.head.{lora_down_key}.weight" in state_dict:
logger.info(
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
)
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
).T
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:
+154 -154
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import Callable, Optional
from typing import Callable, Dict, List, Optional, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -137,7 +137,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -240,7 +240,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -267,7 +267,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -367,7 +367,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -429,7 +429,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -469,9 +469,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@classmethod
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
unet_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
text_encoder_lora_layers: dict[str, torch.nn.Module] = None,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -485,9 +485,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
unet_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `unet`.
text_encoder_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
@@ -531,10 +531,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: list[str] = ["unet", "text_encoder"],
components: List[str] = ["unet", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -543,12 +543,12 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
> [!WARNING] > This is an experimental API.
Args:
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`list[str]`, *optional*):
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
@@ -572,7 +572,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
**kwargs,
)
def unfuse_lora(self, components: list[str] = ["unet", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -580,7 +580,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
> [!WARNING] > This is an experimental API.
Args:
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
@@ -602,7 +602,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -679,7 +679,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -706,7 +706,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -807,7 +807,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -870,7 +870,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -910,10 +910,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@classmethod
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
unet_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
text_encoder_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
text_encoder_2_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -957,10 +957,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: list[str] = ["unet", "text_encoder", "text_encoder_2"],
components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -974,7 +974,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
**kwargs,
)
def unfuse_lora(self, components: list[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -998,7 +998,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -1050,7 +1050,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
@@ -1166,7 +1166,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -1207,10 +1207,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
text_encoder_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
text_encoder_2_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1255,10 +1255,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: list[str] = ["transformer", "text_encoder", "text_encoder_2"],
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -1273,7 +1273,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: list[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -1293,7 +1293,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -1346,7 +1346,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -1421,8 +1421,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1455,10 +1455,10 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -1473,7 +1473,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -1497,7 +1497,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
return_alphas: bool = False,
**kwargs,
):
@@ -1620,7 +1620,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -1782,7 +1782,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
transformer,
prefix=None,
discard_original_layers=False,
) -> dict[str, torch.Tensor]:
) -> Dict[str, torch.Tensor]:
# Remove prefix if present
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
@@ -1851,7 +1851,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -1892,9 +1892,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
text_encoder_lora_layers: dict[str, torch.nn.Module] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1908,9 +1908,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
@@ -1954,10 +1954,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -1984,7 +1984,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
**kwargs,
)
def unfuse_lora(self, components: list[str] = ["transformer", "text_encoder"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -1992,7 +1992,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
> [!WARNING] > This is an experimental API.
Args:
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
@@ -2341,7 +2341,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -2381,9 +2381,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
@classmethod
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
text_encoder_lora_layers: dict[str, torch.nn.Module] = None,
transformer_lora_layers: dict[str, torch.nn.Module] = None,
save_directory: Union[str, os.PathLike],
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -2395,9 +2395,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
unet_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `unet`.
text_encoder_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
@@ -2446,7 +2446,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -2498,7 +2498,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -2572,8 +2572,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
@classmethod
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -2605,10 +2605,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -2622,7 +2622,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
)
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -2642,7 +2642,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -2695,7 +2695,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -2770,8 +2770,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -2804,10 +2804,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -2822,7 +2822,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -2841,7 +2841,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -2898,7 +2898,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -2973,8 +2973,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -3007,10 +3007,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -3025,7 +3025,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -3045,7 +3045,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -3098,7 +3098,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -3173,8 +3173,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -3207,10 +3207,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -3225,7 +3225,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -3244,7 +3244,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -3301,7 +3301,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -3376,8 +3376,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -3410,10 +3410,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -3428,7 +3428,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -3447,7 +3447,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -3505,7 +3505,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -3580,8 +3580,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -3614,10 +3614,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -3632,7 +3632,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -3651,7 +3651,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -3669,7 +3669,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
@@ -3731,7 +3731,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -3832,8 +3832,8 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
@classmethod
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -3846,7 +3846,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
@@ -3879,22 +3879,22 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
Args:
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`list[str]`, *optional*):
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.
Example:
@@ -3914,12 +3914,12 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
**kwargs,
)
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of [`pipe.fuse_lora()`].
Args:
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -3936,7 +3936,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -4040,7 +4040,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -4139,8 +4139,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -4173,10 +4173,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -4191,7 +4191,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -4211,7 +4211,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -4317,7 +4317,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -4416,8 +4416,8 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -4450,10 +4450,10 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -4468,7 +4468,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -4488,7 +4488,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -4541,7 +4541,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -4616,8 +4616,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -4650,10 +4650,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -4668,7 +4668,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -4687,7 +4687,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -4744,7 +4744,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -4819,8 +4819,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -4853,10 +4853,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -4871,7 +4871,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
@@ -4890,7 +4890,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
@@ -4949,7 +4949,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
@@ -5024,8 +5024,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -5058,10 +5058,10 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[list[str]] = None,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
@@ -5076,7 +5076,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
+12 -12
View File
@@ -17,7 +17,7 @@ import json
import os
from functools import partial
from pathlib import Path
from typing import Dict, Literal, Optional
from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
@@ -113,7 +113,7 @@ class PeftAdapterMixin:
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -127,7 +127,7 @@ class PeftAdapterMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -447,16 +447,16 @@ class PeftAdapterMixin:
def set_adapters(
self,
adapter_names: list[str] | str,
weights: Optional[float | Dict | list[float] | list[Dict] | list[None]] = None,
adapter_names: Union[List[str], str],
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
Args:
adapter_names (`list[str]` or `str`):
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[list[float], float]`, *optional*):
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -539,7 +539,7 @@ class PeftAdapterMixin:
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: str | list[str]) -> None:
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
@@ -547,7 +547,7 @@ class PeftAdapterMixin:
[documentation](https://huggingface.co/docs/peft).
Args:
adapter_name (Union[str, list[str]])):
adapter_name (Union[str, List[str]])):
The list of adapters to set or the adapter name in the case of a single adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -633,7 +633,7 @@ class PeftAdapterMixin:
# support for older PEFT versions
module.disable_adapters = False
def active_adapters(self) -> list[str]:
def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model.
@@ -756,12 +756,12 @@ class PeftAdapterMixin:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=True)
def delete_adapters(self, adapter_names: list[str] | str):
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the underlying model.
Args:
adapter_names (`Union[list[str], str]`):
adapter_names (`Union[List[str], str]`):
The names (single string or list of strings) of the adapter to delete.
Example:
+1 -1
View File
@@ -290,7 +290,7 @@ class FromSingleFileMixin:
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
+1 -1
View File
@@ -229,7 +229,7 @@ class FromOriginalModelMixin:
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
+10 -10
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Dict, List, Optional, Union
import safetensors
import torch
@@ -112,7 +112,7 @@ class TextualInversionLoaderMixin:
Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
"""
def maybe_convert_prompt(self, prompt: str | list[str], tokenizer: "PreTrainedTokenizer"): # noqa: F821
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -127,14 +127,14 @@ class TextualInversionLoaderMixin:
Returns:
`str` or list of `str`: The converted prompt
"""
if not isinstance(prompt, list):
if not isinstance(prompt, List):
prompts = [prompt]
else:
prompts = prompt
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
if not isinstance(prompt, list):
if not isinstance(prompt, List):
return prompts[0]
return prompts
@@ -263,8 +263,8 @@ class TextualInversionLoaderMixin:
@validate_hf_hub_args
def load_textual_inversion(
self,
pretrained_model_name_or_path: str | list[str] | dict[str, torch.Tensor] | list[dict[str, torch.Tensor]],
token: Optional[str | list[str]] = None,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
**kwargs,
@@ -274,7 +274,7 @@ class TextualInversionLoaderMixin:
Automatic1111 formats are supported).
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike` or `list[str or os.PathLike]` or `Dict` or `list[Dict]`):
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
Can be either one of the following or a list of them:
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
@@ -285,7 +285,7 @@ class TextualInversionLoaderMixin:
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
token (`str` or `list[str]`, *optional*):
token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
@@ -306,7 +306,7 @@ class TextualInversionLoaderMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -458,7 +458,7 @@ class TextualInversionLoaderMixin:
def unload_textual_inversion(
self,
tokens: Optional[str | list[str]] = None,
tokens: Optional[Union[str, List[str]]] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None,
text_encoder: Optional["PreTrainedModel"] = None,
):
+5 -5
View File
@@ -15,7 +15,7 @@ import os
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Callable
from typing import Callable, Dict, Union
import safetensors
import torch
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], **kwargs):
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
@@ -92,7 +92,7 @@ class UNet2DConditionLoadersMixin:
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -106,7 +106,7 @@ class UNet2DConditionLoadersMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`dict[str, float]`):
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
@@ -412,7 +412,7 @@ class UNet2DConditionLoadersMixin:
def save_attn_procs(
self,
save_directory: str | os.PathLike,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
+9 -7
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Union
from torch import nn
@@ -40,7 +40,9 @@ def _translate_into_actual_layer_name(name):
return ".".join((updown, block, attn))
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: list[float | Dict], default_scale=1.0):
def _maybe_expand_lora_scales(
unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
blocks_with_transformer = {
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
@@ -62,9 +64,9 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: list[
def _maybe_expand_lora_scales_for_one_adapter(
scales: float | Dict,
blocks_with_transformer: dict[str, int],
transformer_per_block: dict[str, int],
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
model: nn.Module,
default_scale: float = 1.0,
):
@@ -74,9 +76,9 @@ def _maybe_expand_lora_scales_for_one_adapter(
Parameters:
scales (`Union[float, Dict]`):
Scales dict to expand.
blocks_with_transformer (`dict[str, int]`):
blocks_with_transformer (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`dict[str, int]`):
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has
E.g. turns
+2 -1
View File
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import torch
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: dict[str, torch.Tensor]):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = dict(enumerate(state_dict.keys()))
-10
View File
@@ -36,8 +36,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -84,7 +82,6 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -94,13 +91,11 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -137,8 +132,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -175,7 +168,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -189,7 +181,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
@@ -201,7 +192,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SD3Transformer2DModel,
+7 -5
View File
@@ -16,7 +16,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
@@ -187,17 +187,19 @@ class ContextParallelOutput:
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across context parallel region.
ContextParallelInputType = dict[
str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
ContextParallelInputType = Dict[
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]
# A dictionary where keys denote the output to be gathered across context parallel region, and the
# value denotes the gathering configuration.
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]
ContextParallelOutputType = Union[
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
# the module should be split/gathered across context parallel region.
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+17 -17
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Optional
from typing import Callable, List, Optional, Union
import torch
import torch.nn as nn
@@ -34,11 +34,11 @@ class MultiAdapter(ModelMixin):
or saving.
Args:
adapters (`list[T2IAdapter]`, *optional*, defaults to None):
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
A list of `T2IAdapter` model instances.
"""
def __init__(self, adapters: list["T2IAdapter"]):
def __init__(self, adapters: List["T2IAdapter"]):
super(MultiAdapter, self).__init__()
self.num_adapter = len(adapters)
@@ -73,7 +73,7 @@ class MultiAdapter(ModelMixin):
self.total_downscale_factor = first_adapter_total_downscale_factor
self.downscale_factor = first_adapter_downscale_factor
def forward(self, xs: torch.Tensor, adapter_weights: Optional[list[float]] = None) -> list[torch.Tensor]:
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
Args:
xs (`torch.Tensor`):
@@ -81,7 +81,7 @@ class MultiAdapter(ModelMixin):
models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
`num_adapter` * number of channel per image.
adapter_weights (`list[float]`, *optional*, defaults to None):
adapter_weights (`List[float]`, *optional*, defaults to None):
A list of floats representing the weights which will be multiplied by each adapter's output before
summing them together. If `None`, equal weights will be used for all adapters.
"""
@@ -104,7 +104,7 @@ class MultiAdapter(ModelMixin):
def save_pretrained(
self,
save_directory: str | os.PathLike,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = True,
@@ -145,7 +145,7 @@ class MultiAdapter(ModelMixin):
model_path_to_save = model_path_to_save + f"_{idx}"
@classmethod
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
@@ -165,7 +165,7 @@ class MultiAdapter(ModelMixin):
Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
@@ -229,7 +229,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
in_channels (`int`, *optional*, defaults to `3`):
The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
image.
channels (`list[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
determines the number of downsample blocks in the adapter.
num_res_blocks (`int`, *optional*, defaults to `2`):
@@ -244,7 +244,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels: int = 3,
channels: list[int] = [320, 640, 1280, 1280],
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 8,
adapter_type: str = "full_adapter",
@@ -263,7 +263,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
"'full_adapter_xl' or 'light_adapter'."
)
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
each representing information extracted at a different scale from the input. The length of the list is
@@ -295,7 +295,7 @@ class FullAdapter(nn.Module):
def __init__(
self,
in_channels: int = 3,
channels: list[int] = [320, 640, 1280, 1280],
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 8,
):
@@ -318,7 +318,7 @@ class FullAdapter(nn.Module):
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method processes the input tensor `x` through the FullAdapter model and performs operations including
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
@@ -345,7 +345,7 @@ class FullAdapterXL(nn.Module):
def __init__(
self,
in_channels: int = 3,
channels: list[int] = [320, 640, 1280, 1280],
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 16,
):
@@ -370,7 +370,7 @@ class FullAdapterXL(nn.Module):
# XL has only one downsampling AdapterBlock.
self.total_downscale_factor = downscale_factor * 2
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
@@ -473,7 +473,7 @@ class LightAdapter(nn.Module):
def __init__(
self,
in_channels: int = 3,
channels: list[int] = [320, 640, 1280],
channels: List[int] = [320, 640, 1280],
num_res_blocks: int = 4,
downscale_factor: int = 8,
):
@@ -496,7 +496,7 @@ class LightAdapter(nn.Module):
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
feature tensor corresponds to a different level of processing within the LightAdapter.
+13 -13
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
class AttentionMixin:
@property
def attn_processors(self) -> dict[str, AttentionProcessor]:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -47,7 +47,7 @@ class AttentionMixin:
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
@@ -61,7 +61,7 @@ class AttentionMixin:
return processors
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -184,7 +184,7 @@ class AttentionModuleMixin:
def set_use_xla_flash_attention(
self,
use_xla_flash_attention: bool,
partition_spec: Optional[tuple[Optional[str], ...]] = None,
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
is_flux=False,
) -> None:
"""
@@ -193,7 +193,7 @@ class AttentionModuleMixin:
Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`tuple[]`, *optional*):
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
@@ -669,8 +669,8 @@ class JointTransformerBlock(nn.Module):
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -950,9 +950,9 @@ class BasicTransformerBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: dict[str, Any] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
@@ -1487,7 +1487,7 @@ class FreeNoiseTransformerBlock(nn.Module):
self._chunk_size = None
self._chunk_dim = 0
def _get_frame_indices(self, num_frames: int) -> list[tuple[int, int]]:
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
@@ -1495,7 +1495,7 @@ class FreeNoiseTransformerBlock(nn.Module):
frame_indices.append((window_start, window_end))
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> list[float]:
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "flat":
weights = [1.0] * num_frames
@@ -1545,7 +1545,7 @@ class FreeNoiseTransformerBlock(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: dict[str, Any] = None,
cross_attention_kwargs: Dict[str, Any] = None,
*args,
**kwargs,
) -> torch.Tensor:
+8 -70
View File
@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import inspect
import math
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
@@ -29,8 +27,6 @@ if torch.distributed.is_available():
from ..utils import (
get_logger,
is_aiter_available,
is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
@@ -51,7 +47,6 @@ if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
@@ -59,7 +54,6 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -84,12 +78,6 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
@@ -190,9 +178,6 @@ class AttentionBackendName(str, Enum):
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# `aiter`
AITER = "aiter"
# PyTorch native
FLEX = "flex"
NATIVE = "native"
@@ -230,7 +215,7 @@ class _AttentionBackendRegistry:
def register(
cls,
backend: AttentionBackendName,
constraints: Optional[list[Callable]] = None,
constraints: Optional[List[Callable]] = None,
supports_context_parallel: bool = False,
):
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
@@ -265,7 +250,7 @@ class _AttentionBackendRegistry:
@contextlib.contextmanager
def attention_backend(backend: str | AttentionBackendName = AttentionBackendName.NATIVE):
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
Context manager to set the active attention backend.
"""
@@ -293,7 +278,7 @@ def dispatch_attention_fn(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
attention_kwargs: Optional[dict[str, Any]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
*,
backend: Optional[AttentionBackendName] = None,
parallel_config: Optional["ParallelConfig"] = None,
@@ -429,12 +414,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -597,7 +576,7 @@ def _wrapped_flash_attn_3(
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
@@ -639,7 +618,7 @@ def _(
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
window_size = (-1, -1) # noqa: F841
# A lot of the parameters here are not yet used in any way within diffusers.
# We can safely ignore for now and keep the fake op shape propagation simple.
@@ -1337,7 +1316,7 @@ def _flash_attention_3_hub(
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: tuple[int, int] = (-1, -1),
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
@@ -1418,47 +1397,6 @@ def _flash_varlen_attention_3(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.AITER,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=True,
)
else:
out = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
@@ -1467,7 +1405,7 @@ def _native_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor | "flex_attention.BlockMask"] = None,
attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+68 -68
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
import math
from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -309,7 +309,7 @@ class Attention(nn.Module):
def set_use_xla_flash_attention(
self,
use_xla_flash_attention: bool,
partition_spec: Optional[tuple[Optional[str], ...]] = None,
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
is_flux=False,
) -> None:
r"""
@@ -318,7 +318,7 @@ class Attention(nn.Module):
Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`tuple[]`, *optional*):
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
"""
if use_xla_flash_attention:
@@ -872,7 +872,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: tuple[int, ...] = (5,),
kernel_sizes: Tuple[int, ...] = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
@@ -2790,7 +2790,7 @@ class XLAFlashAttnProcessor2_0:
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
def __init__(self, partition_spec: Optional[tuple[Optional[str], ...]] = None):
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
@@ -3001,7 +3001,7 @@ class StableAudioAttnProcessor2_0:
def apply_partial_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: tuple[torch.Tensor],
freqs_cis: Tuple[torch.Tensor],
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
@@ -4212,9 +4212,9 @@ class IPAdapterAttnProcessor(nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or list[`float`], defaults to 1.0):
scale (`float` or List[`float`], defaults to 1.0):
the weight scale of image prompt.
"""
@@ -4305,7 +4305,7 @@ class IPAdapterAttnProcessor(nn.Module):
hidden_states = attn.batch_to_head_dim(hidden_states)
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, list):
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
@@ -4412,9 +4412,9 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `list[float]`, defaults to 1.0):
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
@@ -4524,7 +4524,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.to(query.dtype)
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, list):
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
@@ -4644,9 +4644,9 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `list[float]`, defaults to 1.0):
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
@@ -4763,7 +4763,7 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
if ip_hidden_states:
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, list):
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
@@ -5622,56 +5622,56 @@ CROSS_ATTENTION_PROCESSORS = (
FluxIPAdapterJointAttnProcessor2_0,
)
AttentionProcessor = (
AttnProcessor
| CustomDiffusionAttnProcessor
| AttnAddedKVProcessor
| AttnAddedKVProcessor2_0
| JointAttnProcessor2_0
| PAGJointAttnProcessor2_0
| PAGCFGJointAttnProcessor2_0
| FusedJointAttnProcessor2_0
| AllegroAttnProcessor2_0
| AuraFlowAttnProcessor2_0
| FusedAuraFlowAttnProcessor2_0
| FluxAttnProcessor2_0
| FluxAttnProcessor2_0_NPU
| FusedFluxAttnProcessor2_0
| FusedFluxAttnProcessor2_0_NPU
| CogVideoXAttnProcessor2_0
| FusedCogVideoXAttnProcessor2_0
| XFormersAttnAddedKVProcessor
| XFormersAttnProcessor
| XLAFlashAttnProcessor2_0
| AttnProcessorNPU
| AttnProcessor2_0
| MochiVaeAttnProcessor2_0
| MochiAttnProcessor2_0
| StableAudioAttnProcessor2_0
| HunyuanAttnProcessor2_0
| FusedHunyuanAttnProcessor2_0
| PAGHunyuanAttnProcessor2_0
| PAGCFGHunyuanAttnProcessor2_0
| LuminaAttnProcessor2_0
| FusedAttnProcessor2_0
| CustomDiffusionXFormersAttnProcessor
| CustomDiffusionAttnProcessor2_0
| SlicedAttnProcessor
| SlicedAttnAddedKVProcessor
| SanaLinearAttnProcessor2_0
| PAGCFGSanaLinearAttnProcessor2_0
| PAGIdentitySanaLinearAttnProcessor2_0
| SanaMultiscaleLinearAttention
| SanaMultiscaleAttnProcessor2_0
| SanaMultiscaleAttentionProjection
| IPAdapterAttnProcessor
| IPAdapterAttnProcessor2_0
| IPAdapterXFormersAttnProcessor
| SD3IPAdapterJointAttnProcessor2_0
| PAGIdentitySelfAttnProcessor2_0
| PAGCFGIdentitySelfAttnProcessor2_0
| LoRAAttnProcessor
| LoRAAttnProcessor2_0
| LoRAXFormersAttnProcessor
| LoRAAttnAddedKVProcessor
)
AttentionProcessor = Union[
AttnProcessor,
CustomDiffusionAttnProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
JointAttnProcessor2_0,
PAGJointAttnProcessor2_0,
PAGCFGJointAttnProcessor2_0,
FusedJointAttnProcessor2_0,
AllegroAttnProcessor2_0,
AuraFlowAttnProcessor2_0,
FusedAuraFlowAttnProcessor2_0,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
FusedFluxAttnProcessor2_0_NPU,
CogVideoXAttnProcessor2_0,
FusedCogVideoXAttnProcessor2_0,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
XLAFlashAttnProcessor2_0,
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,
MochiAttnProcessor2_0,
StableAudioAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FusedHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
SlicedAttnProcessor,
SlicedAttnAddedKVProcessor,
SanaLinearAttnProcessor2_0,
PAGCFGSanaLinearAttnProcessor2_0,
PAGIdentitySanaLinearAttnProcessor2_0,
SanaMultiscaleLinearAttention,
SanaMultiscaleAttnProcessor2_0,
SanaMultiscaleAttentionProjection,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
]
+4 -4
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import Optional
from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args
@@ -37,7 +37,7 @@ class AutoModel(ConfigMixin):
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path: Optional[str | os.PathLike] = None, **kwargs):
def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
r"""
Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -61,7 +61,7 @@ class AutoModel(ConfigMixin):
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`dict[str, str]`, *optional*):
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info (`bool`, *optional*, defaults to `False`):
@@ -83,7 +83,7 @@ class AutoModel(ConfigMixin):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
@@ -5,8 +5,6 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.
@@ -34,16 +34,16 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
tuple of downsample block types.
down_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
tuple of down block output channels.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of down block output channels.
layers_per_down_block (`int`, *optional*, defaults to `1`):
Number layers for down block.
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
tuple of upsample block types.
up_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
tuple of up block output channels.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of up block output channels.
layers_per_up_block (`int`, *optional*, defaults to `1`):
Number layers for up block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
@@ -67,11 +67,11 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
down_block_out_channels: tuple[int, ...] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
down_block_out_channels: Tuple[int, ...] = (64,),
layers_per_down_block: int = 1,
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
up_block_out_channels: tuple[int, ...] = (64,),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
up_block_out_channels: Tuple[int, ...] = (64,),
layers_per_up_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -107,11 +107,14 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput | tuple[torch.Tensor]:
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@@ -127,7 +130,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask)
@@ -144,7 +147,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
decoded = self._decode(z, image, mask).sample
if not return_dict:
@@ -159,7 +162,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
from .vae import DecoderOutput, EncoderOutput
class ResBlock(nn.Module):
@@ -68,7 +68,7 @@ class EfficientViTBlock(nn.Module):
in_channels: int,
mult: float = 1.0,
attention_head_dim: int = 32,
qkv_multiscales: tuple[int, ...] = (5,),
qkv_multiscales: Tuple[int, ...] = (5,),
norm_type: str = "batch_norm",
) -> None:
super().__init__()
@@ -102,7 +102,7 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: tuple[int] = (),
qkv_mutliscales: Tuple[int] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
@@ -205,10 +205,10 @@ class Encoder(nn.Module):
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str | tuple[str] = "ResBlock",
block_out_channels: tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple[int] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
):
@@ -291,12 +291,12 @@ class Decoder(nn.Module):
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str | tuple[str] = "ResBlock",
block_out_channels: tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple[int] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: str | tuple[str] = "rms_norm",
act_fn: str | tuple[str] = "silu",
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
conv_act_fn: str = "relu",
@@ -378,7 +378,7 @@ class Decoder(nn.Module):
return hidden_states
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629).
@@ -391,29 +391,29 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
The number of input channels in samples.
latent_channels (`int`, defaults to `32`):
The number of channels in the latent space representation.
encoder_block_types (`Union[str, tuple[str]]`, defaults to `"ResBlock"`):
encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
The type(s) of block to use in the encoder.
decoder_block_types (`Union[str, tuple[str]]`, defaults to `"ResBlock"`):
decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
The type(s) of block to use in the decoder.
encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
The number of output channels for each block in the encoder.
decoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
The number of output channels for each block in the decoder.
encoder_layers_per_block (`tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
The number of layers per block in the encoder.
decoder_layers_per_block (`tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
The number of layers per block in the decoder.
encoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
decoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
The type of block to use for upsampling in the decoder.
downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
The type of block to use for downsampling in the encoder.
decoder_norm_types (`Union[str, tuple[str]]`, defaults to `"rms_norm"`):
decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
The normalization type(s) to use in the decoder.
decoder_act_fns (`Union[str, tuple[str]]`, defaults to `"silu"`):
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
The activation function(s) to use in the decoder.
encoder_out_shortcut (`bool`, defaults to `True`):
Whether to use shortcut at the end of the encoder.
@@ -436,18 +436,18 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
in_channels: int = 3,
latent_channels: int = 32,
attention_head_dim: int = 32,
encoder_block_types: str | tuple[str] = "ResBlock",
decoder_block_types: str | tuple[str] = "ResBlock",
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: tuple[int] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: tuple[int] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle",
downsample_block_type: str = "pixel_unshuffle",
decoder_norm_types: str | tuple[str] = "rms_norm",
decoder_act_fns: str | tuple[str] = "silu",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
encoder_out_shortcut: bool = True,
decoder_in_shortcut: bool = True,
decoder_conv_act_fn: str = "relu",
@@ -536,6 +536,27 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
@@ -547,7 +568,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
return encoded
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> EncoderOutput | tuple[torch.Tensor]:
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
r"""
Encode a batch of images into latents.
@@ -581,7 +602,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
return decoded
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Decode a batch of images.
@@ -665,7 +686,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
return (encoded,)
return EncoderOutput(latent=encoded)
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -32,10 +32,10 @@ from ..attention_processor import (
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -45,12 +45,12 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
tuple of downsample block types.
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
tuple of upsample block types.
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
tuple of block output channels.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str] = ("DownEncoderBlock2D",),
up_block_types: tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: tuple[int] = (64,),
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -88,8 +88,8 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
sample_size: int = 32,
scaling_factor: float = 0.18215,
shift_factor: Optional[float] = None,
latents_mean: Optional[tuple[float]] = None,
latents_std: Optional[tuple[float]] = None,
latents_mean: Optional[Tuple[float]] = None,
latents_std: Optional[Tuple[float]] = None,
force_upcast: bool = True,
use_quant_conv: bool = True,
use_post_quant_conv: bool = True,
@@ -138,9 +138,38 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> dict[str, AttentionProcessor]:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -149,7 +178,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
@@ -164,7 +193,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -229,7 +258,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -255,7 +284,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
@@ -272,7 +301,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> DecoderOutput | torch.FloatTensor:
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
@@ -420,7 +449,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -475,7 +504,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Optional
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -28,7 +28,6 @@ from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
from .vae import AutoencoderMixin
class AllegroTemporalConvLayer(nn.Module):
@@ -417,14 +416,14 @@ class AllegroEncoder3D(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str, ...] = (
down_block_types: Tuple[str, ...] = (
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
temporal_downsample_blocks: tuple[bool, ...] = [True, True, False, False],
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -544,14 +543,14 @@ class AllegroDecoder3D(nn.Module):
self,
in_channels: int = 4,
out_channels: int = 3,
up_block_types: tuple[str, ...] = (
up_block_types: Tuple[str, ...] = (
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
),
temporal_upsample_blocks: tuple[bool, ...] = [False, True, True, False],
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -674,7 +673,7 @@ class AllegroDecoder3D(nn.Module):
return sample
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -687,14 +686,14 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
Number of channels in the input image.
out_channels (int, defaults to `3`):
Number of channels in the output.
down_block_types (`tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
tuple of strings denoting which types of down blocks to use.
up_block_types (`tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
tuple of strings denoting which types of up blocks to use.
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
tuple of integers denoting number of output channels in each block.
temporal_downsample_blocks (`tuple[bool, ...]`, defaults to `(True, True, False, False)`):
tuple of booleans denoting which blocks to enable temporal downsampling in.
down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
Tuple of strings denoting which types of down blocks to use.
up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
Tuple of strings denoting which types of up blocks to use.
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
Tuple of integers denoting number of output channels in each block.
temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
Tuple of booleans denoting which blocks to enable temporal downsampling in.
latent_channels (`int`, defaults to `4`):
Number of channels in latents.
layers_per_block (`int`, defaults to `2`):
@@ -727,21 +726,21 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str, ...] = (
down_block_types: Tuple[str, ...] = (
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
"AllegroDownBlock3D",
),
up_block_types: tuple[str, ...] = (
up_block_types: Tuple[str, ...] = (
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
"AllegroUpBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
temporal_downsample_blocks: tuple[bool, ...] = (True, True, False, False),
temporal_upsample_blocks: tuple[bool, ...] = (False, True, True, False),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
latent_channels: int = 4,
layers_per_block: int = 2,
act_fn: str = "silu",
@@ -796,6 +795,35 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_size - self.tile_overlap_w,
)
def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
@@ -807,7 +835,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of videos into latents.
@@ -842,7 +870,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
raise NotImplementedError("Decoding without tiling has not been implemented yet.")
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of videos.
@@ -1045,7 +1073,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -72,7 +72,7 @@ class CogVideoXCausalConv3d(nn.Module):
Args:
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `tuple[int, int, int]`): Kernel size of the convolutional kernel.
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (`int`, defaults to `1`): Stride of the convolution.
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (`str`, defaults to `"constant"`): Padding mode.
@@ -82,7 +82,7 @@ class CogVideoXCausalConv3d(nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int],
kernel_size: Union[int, Tuple[int, int, int]],
stride: int = 1,
dilation: int = 1,
pad_mode: str = "constant",
@@ -174,7 +174,7 @@ class CogVideoXSpatialNorm3D(nn.Module):
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
@@ -289,7 +289,7 @@ class CogVideoXResnetBlock3D(nn.Module):
inputs: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
@@ -411,7 +411,7 @@ class CogVideoXDownBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXDownBlock3D` class."""
@@ -506,7 +506,7 @@ class CogVideoXMidBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXMidBlock3D` class."""
@@ -613,7 +613,7 @@ class CogVideoXUpBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXUpBlock3D` class."""
@@ -652,10 +652,10 @@ class CogVideoXEncoder3D(nn.Module):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
@@ -671,13 +671,13 @@ class CogVideoXEncoder3D(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 16,
down_block_types: tuple[str, ...] = (
down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 256, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
@@ -744,7 +744,7 @@ class CogVideoXEncoder3D(nn.Module):
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""The forward method of the `CogVideoXEncoder3D` class."""
@@ -805,9 +805,9 @@ class CogVideoXDecoder3D(nn.Module):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
@@ -823,13 +823,13 @@ class CogVideoXDecoder3D(nn.Module):
self,
in_channels: int = 16,
out_channels: int = 3,
up_block_types: tuple[str, ...] = (
up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 256, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
layers_per_block: int = 3,
act_fn: str = "silu",
norm_eps: float = 1e-6,
@@ -903,7 +903,7 @@ class CogVideoXDecoder3D(nn.Module):
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""The forward method of the `CogVideoXDecoder3D` class."""
@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module):
return hidden_states, new_conv_cache
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -966,12 +966,12 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
tuple of downsample block types.
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
tuple of upsample block types.
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
tuple of block output channels.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str] = (
down_block_types: Tuple[str] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types: tuple[str] = (
up_block_types: Tuple[str] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: tuple[int] = (128, 256, 256, 512),
block_out_channels: Tuple[int] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
@@ -1018,8 +1018,8 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
sample_width: int = 720,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[tuple[float]] = None,
latents_std: Optional[tuple[float]] = None,
latents_mean: Optional[Tuple[float]] = None,
latents_std: Optional[Tuple[float]] = None,
force_upcast: float = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
@@ -1124,6 +1124,27 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -1153,7 +1174,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -1178,7 +1199,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
@@ -1207,7 +1228,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -1321,7 +1342,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
enc = torch.cat(result_rows, dim=3)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1410,7 +1431,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor | torch.Tensor:
) -> Union[torch.Tensor, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -26,7 +24,7 @@ from ...utils import get_logger
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
from .vae import DecoderOutput, IdentityDistribution
logger = get_logger(__name__)
@@ -49,9 +47,9 @@ class CosmosCausalConv3d(nn.Conv3d):
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_size: int | tuple[int, int, int] = (3, 3, 3),
dilation: int | tuple[int, int, int] = (1, 1, 1),
stride: int | tuple[int, int, int] = (1, 1, 1),
kernel_size: Union[int, Tuple[int, int, int]] = (3, 3, 3),
dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
stride: Union[int, Tuple[int, int, int]] = (1, 1, 1),
padding: int = 1,
pad_mode: str = "constant",
) -> None:
@@ -421,7 +419,7 @@ class CosmosCausalAttention(nn.Module):
attention_head_dim: int,
num_groups: int = 1,
dropout: float = 0.0,
processor: "CosmosSpatialAttentionProcessor2_0" | "CosmosTemporalAttentionProcessor2_0" = None,
processor: Union["CosmosSpatialAttentionProcessor2_0", "CosmosTemporalAttentionProcessor2_0"] = None,
) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -713,9 +711,9 @@ class CosmosEncoder3d(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 16,
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
num_resnet_blocks: int = 2,
attention_resolutions: tuple[int, ...] = (32,),
attention_resolutions: Tuple[int, ...] = (32,),
resolution: int = 1024,
patch_size: int = 4,
patch_type: str = "haar",
@@ -797,9 +795,9 @@ class CosmosDecoder3d(nn.Module):
self,
in_channels: int = 16,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
num_resnet_blocks: int = 2,
attention_resolutions: tuple[int, ...] = (32,),
attention_resolutions: Tuple[int, ...] = (32,),
resolution: int = 1024,
patch_size: int = 4,
patch_type: str = "haar",
@@ -877,7 +875,7 @@ class CosmosDecoder3d(nn.Module):
return hidden_states
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
@@ -888,12 +886,12 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
Number of output channels.
latent_channels (`int`, defaults to `16`):
Number of latent channels.
encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
Number of output channels for each encoder down block.
decode_block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
Number of output channels for each decoder up block.
attention_resolutions (`tuple[int, ...]`, defaults to `(32,)`):
list of image/video resolutions at which to apply attention.
attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
List of image/video resolutions at which to apply attention.
resolution (`int`, defaults to `1024`):
Base image/video resolution used for computing whether a block should have attention layers.
num_layers (`int`, defaults to `2`):
@@ -926,9 +924,9 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 16,
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
decode_block_out_channels: tuple[int, ...] = (256, 512, 512, 512),
attention_resolutions: tuple[int, ...] = (32,),
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
decode_block_out_channels: Tuple[int, ...] = (256, 512, 512, 512),
attention_resolutions: Tuple[int, ...] = (32,),
resolution: int = 1024,
num_layers: int = 2,
patch_size: int = 4,
@@ -936,8 +934,8 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
scaling_factor: float = 1.0,
spatial_compression_ratio: int = 8,
temporal_compression_ratio: int = 8,
latents_mean: Optional[list[float]] = LATENTS_MEAN,
latents_std: Optional[list[float]] = LATENTS_STD,
latents_mean: Optional[List[float]] = LATENTS_MEAN,
latents_std: Optional[List[float]] = LATENTS_STD,
) -> None:
super().__init__()
@@ -1033,6 +1031,27 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
@@ -1052,7 +1071,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
@@ -1061,7 +1080,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
@@ -1078,7 +1097,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor] | DecoderOutput:
) -> Union[Tuple[torch.Tensor], DecoderOutput]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -26,7 +26,7 @@ from ..activations import get_activation
from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -50,10 +50,10 @@ class HunyuanVideoCausalConv3d(nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int] = 3,
stride: int | tuple[int, int, int] = 1,
padding: int | tuple[int, int, int] = 0,
dilation: int | tuple[int, int, int] = 1,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
@@ -86,7 +86,7 @@ class HunyuanVideoUpsampleCausal3D(nn.Module):
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
upsample_factor: tuple[float, float, float] = (2, 2, 2),
upsample_factor: Tuple[float, float, float] = (2, 2, 2),
) -> None:
super().__init__()
@@ -357,7 +357,7 @@ class HunyuanVideoUpBlock3D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
add_upsample: bool = True,
upsample_scale_factor: tuple[int, int, int] = (2, 2, 2),
upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
) -> None:
super().__init__()
resnets = []
@@ -418,13 +418,13 @@ class HunyuanVideoEncoder3D(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str, ...] = (
down_block_types: Tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -526,13 +526,13 @@ class HunyuanVideoDecoder3D(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: tuple[str, ...] = (
up_block_types: Tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -624,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -641,19 +641,19 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 16,
down_block_types: tuple[str, ...] = (
down_block_types: Tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
up_block_types: tuple[str, ...] = (
up_block_types: Tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int] = (128, 256, 512, 512),
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
@@ -763,6 +763,27 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -779,7 +800,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
@@ -804,7 +825,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
@@ -825,7 +846,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
@@ -924,7 +945,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1013,7 +1034,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
return enc
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
@@ -1055,7 +1076,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -1,709 +0,0 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageResnetBlock(nn.Module):
r"""
Residual block with two convolutions and optional channel change.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x):
# Apply shortcut connection
residual = x
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
# Add residual connection
return x + residual
class HunyuanImageAttentionBlock(nn.Module):
r"""
Self-attention with a single head.
Args:
in_channels (int): The number of channels in the input tensor.
"""
def __init__(self, in_channels: int):
super().__init__()
# layers
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
self.proj = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
identity = x
x = self.norm(x)
# compute query, key, value
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, height, width = query.shape
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
# apply attention
x = F.scaled_dot_product_attention(query, key, value)
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
# output projection
x = self.proj(x)
return x + identity
class HunyuanImageDownsample(nn.Module):
"""
Downsampling block for spatial reduction.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
if out_channels % factor != 0:
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.group_size = factor * in_channels // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
h = h.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = x.shape
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
B, C, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageUpsample(nn.Module):
"""
Upsampling block for spatial expansion.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
factor = 4
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
self.repeats = factor * out_channels // in_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv(x)
B, C, H, W = h.shape
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
h = h.reshape(B, C // 4, H * 2, W * 2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
B, C, H, W = shortcut.shape
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
return h + shortcut
class HunyuanImageMidBlock(nn.Module):
"""
Middle block for HunyuanImageVAE encoder and decoder.
Args:
in_channels (int): Number of input channels.
num_layers (int): Number of layers.
"""
def __init__(self, in_channels: int, num_layers: int = 1):
super().__init__()
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
attentions = []
for _ in range(num_layers):
attentions.append(HunyuanImageAttentionBlock(in_channels))
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.resnets[0](x)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
x = attn(x)
x = resnet(x)
return x
class HunyuanImageEncoder2D(nn.Module):
r"""
Encoder network that compresses input to latent representation.
Args:
in_channels (int): Number of input channels.
z_channels (int): Number of latent channels.
block_out_channels (list of int): Output channels for each block.
num_res_blocks (int): Number of residual blocks per block.
spatial_compression_ratio (int): Spatial downsampling factor.
non_linearity (str): Type of non-linearity to use. Default is "silu".
downsample_match_channel (bool): Whether to match channels during downsampling.
"""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
non_linearity: str = "silu",
downsample_match_channel: bool = True,
):
super().__init__()
if block_out_channels[-1] % (2 * z_channels) != 0:
raise ValueError(
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
)
self.in_channels = in_channels
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.spatial_compression_ratio = spatial_compression_ratio
self.group_size = block_out_channels[-1] // (2 * z_channels)
self.nonlinearity = get_activation(non_linearity)
# init block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
block_in_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
# residual blocks
for _ in range(num_res_blocks):
self.down_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# downsample block
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if downsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.down_blocks.append(
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
# middle blocks
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
# output blocks
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_in(x)
## downsamples
for down_block in self.down_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(down_block, x)
else:
x = down_block(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(self.mid_block, x)
else:
x = self.mid_block(x)
## head
B, C, H, W = x.shape
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
x = self.norm_out(x)
x = self.nonlinearity(x)
x = self.conv_out(x)
return x + residual
class HunyuanImageDecoder2D(nn.Module):
r"""
Decoder network that reconstructs output from latent representation.
Args:
z_channels : int
Number of latent channels.
out_channels : int
Number of output channels.
block_out_channels : Tuple[int, ...]
Output channels for each block.
num_res_blocks : int
Number of residual blocks per block.
spatial_compression_ratio : int
Spatial upsampling factor.
upsample_match_channel : bool
Whether to match channels during upsampling.
non_linearity (str): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
spatial_compression_ratio: int,
upsample_match_channel: bool = True,
non_linearity: str = "silu",
):
super().__init__()
if block_out_channels[0] % z_channels != 0:
raise ValueError(
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
)
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.repeat = block_out_channels[0] // z_channels
self.spatial_compression_ratio = spatial_compression_ratio
self.nonlinearity = get_activation(non_linearity)
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# Middle blocks with attention
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
# Upsampling blocks
block_in_channel = block_out_channels[0]
self.up_blocks = nn.ModuleList()
for i in range(len(block_out_channels)):
block_out_channel = block_out_channels[i]
for _ in range(self.num_res_blocks + 1):
self.up_blocks.append(
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
)
block_in_channel = block_out_channel
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
if upsample_match_channel:
block_out_channel = block_out_channels[i + 1]
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
block_in_channel = block_out_channel
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid_block, h)
else:
h = self.mid_block(h)
for up_block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(up_block, h)
else:
h = up_block(h)
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
# fmt: off
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
spatial_compression_ratio: int,
sample_size: int,
scaling_factor: float = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
) -> None:
# fmt: on
super().__init__()
self.encoder = HunyuanImageEncoder2D(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageDecoder2D(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
# Tiling and slicing configuration
self.use_slicing = False
self.use_tiling = False
# Tiling parameters
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // spatial_compression_ratio
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_size: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
allow processing larger images.
Args:
tile_sample_min_size (`int`, *optional*):
The minimum size required for a sample to be separated into tiles across the spatial dimension.
tile_overlap_factor (`float`, *optional*):
The overlap factor required for a latent to be separated into tiles across the spatial dimension.
"""
self.use_tiling = True
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self.tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
batch_size, num_channels, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode input using spatial tiling strategy.
Args:
x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
Returns:
`torch.Tensor`:
The latent representation of the encoded images.
"""
_, _, _, height, width = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode latent using spatial tiling strategy.
Args:
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, height, width = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, height, overlap_size):
row = []
for j in range(0, width, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -1,934 +0,0 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanImageRefinerCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
return self.conv(hidden_states)
class HunyuanImageRefinerRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.norm(x)
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, frames, height, width = query.shape
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
# batch_size, 1, frames * height * width, channels
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
x = self.proj_out(x)
return x + identity
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
Args:
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
r1: temporal upsampling factor
r2: height upsampling factor
r3: width upsampling factor
"""
b, packed_c, f, h, w = tensor.shape
factor = r1 * r2 * r3
c = packed_c // factor
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
h = h[:, : h.shape[1] // 2]
# shortcut computation
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
else:
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
return h + shortcut
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
This packs spatial/temporal dimensions into channels (opposite of upsample)
"""
b, c, packed_f, packed_h, packed_w = tensor.shape
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
h = torch.cat([h, h], dim=1)
# shortcut computation
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
else:
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
class HunyuanImageRefinerResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
class HunyuanImageRefinerMidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
add_attention: bool = True,
) -> None:
super().__init__()
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
else:
attentions.append(None)
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanImageRefinerDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample_out_channels: Optional[int] = None,
add_temporal_downsample: int = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample_out_channels is not None:
self.downsamplers = nn.ModuleList(
[
HunyuanImageRefinerDownsampleDCAE(
out_channels,
out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
upsample_out_channels: Optional[int] = None,
add_temporal_upsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanImageRefinerResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample_out_channels is not None:
self.upsamplers = nn.ModuleList(
[
HunyuanImageRefinerUpsampleDCAE(
out_channels,
out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanImageRefinerEncoder3D(nn.Module):
r"""
3D vae encoder for HunyuanImageRefiner.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 64,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 16,
downsample_match_channel: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.group_size = block_out_channels[-1] // self.out_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
output_channel = block_out_channels[i]
if not add_spatial_downsample:
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=None,
add_temporal_downsample=False,
)
input_channel = output_channel
else:
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
down_block = HunyuanImageRefinerDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
input_channel = downsample_out_channels
self.down_blocks.append(down_block)
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states += short_cut
return hidden_states
class HunyuanImageRefinerDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
"""
def __init__(
self,
in_channels: int = 32,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
upsample_match_channel: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.in_channels = in_channels
self.out_channels = out_channels
self.repeat = block_out_channels[0] // self.in_channels
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
# up
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
output_channel = block_out_channels[i]
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
if add_spatial_upsample or add_temporal_upsample:
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
input_channel = upsample_out_channels
else:
up_block = HunyuanImageRefinerUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=None,
add_temporal_upsample=False,
)
input_channel = output_channel
self.up_blocks.append(up_block)
# out
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
scaling_factor: float = 1.03682,
) -> None:
super().__init__()
self.encoder = HunyuanImageRefinerEncoder3D(
in_channels=in_channels,
out_channels=latent_channels * 2,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanImageRefinerDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
x = self.encoder(x)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z)
dec = self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, _, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
class LTXVideoCausalConv3d(nn.Module):
@@ -34,9 +34,9 @@ class LTXVideoCausalConv3d(nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int] = 3,
stride: int | tuple[int, int, int] = 1,
dilation: int | tuple[int, int, int] = 1,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
padding_mode: str = "zeros",
is_causal: bool = True,
@@ -201,7 +201,7 @@ class LTXVideoDownsampler3d(nn.Module):
self,
in_channels: int,
out_channels: int,
stride: int | tuple[int, int, int] = 1,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
) -> None:
@@ -249,7 +249,7 @@ class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
stride: int | tuple[int, int, int] = 1,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
@@ -735,11 +735,11 @@ class LTXVideoEncoder3d(nn.Module):
Number of input channels.
out_channels (`int`, defaults to 128):
Number of latent channels.
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal downscaling layers or not.
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
@@ -755,16 +755,16 @@ class LTXVideoEncoder3d(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
down_block_types: tuple[str, ...] = (
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -888,11 +888,11 @@ class LTXVideoDecoder3d(nn.Module):
Number of latent channels.
out_channels (`int`, defaults to 3):
Number of output channels.
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal upscaling layers or not.
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
@@ -910,17 +910,17 @@ class LTXVideoDecoder3d(nn.Module):
self,
in_channels: int = 128,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: tuple[bool, ...] = (False, False, False, False),
inject_noise: Tuple[bool, ...] = (False, False, False, False),
timestep_conditioning: bool = False,
upsample_residual: tuple[bool, ...] = (False, False, False, False),
upsample_factor: tuple[bool, ...] = (1, 1, 1, 1),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
) -> None:
super().__init__()
@@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module):
return hidden_states
class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1049,11 +1049,11 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
Number of output channels.
latent_channels (`int`, defaults to `128`):
Number of latent channels.
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal downscaling or not.
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
@@ -1082,22 +1082,22 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
down_block_types: tuple[str, ...] = (
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False, False),
downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: tuple[bool, ...] = (False, False, False, False),
upsample_factor: tuple[int, ...] = (1, 1, 1, 1),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
patch_size: int = 4,
patch_size_t: int = 1,
@@ -1219,6 +1219,27 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -1235,7 +1256,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -1261,7 +1282,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
def _decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
@@ -1283,7 +1304,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
@apply_forward_hook
def decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -1390,7 +1411,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
def tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1480,7 +1501,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
def _temporal_tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
@@ -1523,7 +1544,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor | torch.Tensor:
) -> Union[torch.Tensor, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Optional
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -37,10 +37,10 @@ class EasyAnimateCausalConv3d(nn.Conv3d):
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...] = 3,
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] = 1,
dilation: int | tuple[int, ...] = 1,
kernel_size: Union[int, Tuple[int, ...]] = 3,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 1,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
@@ -437,13 +437,13 @@ class EasyAnimateEncoder(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 8,
down_block_types: tuple[str, ...] = (
down_block_types: Tuple[str, ...] = (
"SpatialDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
),
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -553,13 +553,13 @@ class EasyAnimateDecoder(nn.Module):
self,
in_channels: int = 8,
out_channels: int = 3,
up_block_types: tuple[str, ...] = (
up_block_types: Tuple[str, ...] = (
"SpatialUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
),
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module):
return hidden_states
class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
@@ -680,14 +680,14 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
in_channels: int = 3,
latent_channels: int = 16,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
down_block_types: tuple[str, ...] = [
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
down_block_types: Tuple[str, ...] = [
"SpatialDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
"SpatialTemporalDownBlock3D",
],
up_block_types: tuple[str, ...] = [
up_block_types: Tuple[str, ...] = [
"SpatialUpBlock3D",
"SpatialTemporalUpBlock3D",
"SpatialTemporalUpBlock3D",
@@ -805,10 +805,31 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def _encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -838,7 +859,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -863,7 +884,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
@@ -890,7 +911,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -983,7 +1004,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return moments
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
@@ -1050,7 +1071,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -14,7 +14,7 @@
# limitations under the License.
import functools
from typing import Optional
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -106,7 +106,7 @@ class MochiResnetBlock3D(nn.Module):
def forward(
self,
inputs: torch.Tensor,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
new_conv_cache = {}
conv_cache = conv_cache or {}
@@ -193,7 +193,7 @@ class MochiDownBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
chunk_size: int = 2**15,
) -> torch.Tensor:
r"""Forward method of the `MochiUpBlock3D` class."""
@@ -294,7 +294,7 @@ class MochiMidBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `MochiMidBlock3D` class."""
@@ -368,7 +368,7 @@ class MochiUpBlock3D(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
conv_cache: Optional[dict[str, torch.Tensor]] = None,
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
r"""Forward method of the `MochiUpBlock3D` class."""
@@ -445,13 +445,13 @@ class MochiEncoder3D(nn.Module):
The number of input channels.
out_channels (`int`, *optional*):
The number of output channels.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
The number of output channels for each block.
layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
The number of resnet blocks for each block.
temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
The temporal expansion factor for each of the up blocks.
spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
The spatial expansion factor for each of the up blocks.
non_linearity (`str`, *optional*, defaults to `"swish"`):
The non-linearity to use in the decoder.
@@ -461,11 +461,11 @@ class MochiEncoder3D(nn.Module):
self,
in_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...] = (128, 256, 512, 768),
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
temporal_expansions: tuple[int, ...] = (1, 2, 3),
spatial_expansions: tuple[int, ...] = (2, 2, 2),
add_attention_block: tuple[bool, ...] = (False, True, True, True, True),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
act_fn: str = "swish",
):
super().__init__()
@@ -500,7 +500,7 @@ class MochiEncoder3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
r"""Forward method of the `MochiEncoder3D` class."""
@@ -558,13 +558,13 @@ class MochiDecoder3D(nn.Module):
The number of input channels.
out_channels (`int`, *optional*):
The number of output channels.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
The number of output channels for each block.
layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
The number of resnet blocks for each block.
temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
The temporal expansion factor for each of the up blocks.
spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
The spatial expansion factor for each of the up blocks.
non_linearity (`str`, *optional*, defaults to `"swish"`):
The non-linearity to use in the decoder.
@@ -574,10 +574,10 @@ class MochiDecoder3D(nn.Module):
self,
in_channels: int, # 12
out_channels: int, # 3
block_out_channels: tuple[int, ...] = (128, 256, 512, 768),
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
temporal_expansions: tuple[int, ...] = (1, 2, 3),
spatial_expansions: tuple[int, ...] = (2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
act_fn: str = "swish",
):
super().__init__()
@@ -613,7 +613,7 @@ class MochiDecoder3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
r"""Forward method of the `MochiDecoder3D` class."""
@@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module):
return hidden_states, new_conv_cache
class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLMochi(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models).
@@ -668,8 +668,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
tuple of block output channels.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
@@ -688,15 +688,15 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 15,
out_channels: int = 3,
encoder_block_out_channels: tuple[int] = (64, 128, 256, 384),
decoder_block_out_channels: tuple[int] = (128, 256, 512, 768),
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
latent_channels: int = 12,
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu",
temporal_expansions: tuple[int, ...] = (1, 2, 3),
spatial_expansions: tuple[int, ...] = (2, 2, 2),
add_attention_block: tuple[bool, ...] = (False, True, True, True, True),
latents_mean: tuple[float, ...] = (
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
latents_mean: Tuple[float, ...] = (
-0.06730895953510081,
-0.038011381506090416,
-0.07477820912866141,
@@ -710,7 +710,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
-0.011931556316503654,
-0.0321993391887285,
),
latents_std: tuple[float, ...] = (
latents_std: Tuple[float, ...] = (
0.9263795028493863,
0.9248894543193766,
0.9393059390890617,
@@ -818,6 +818,27 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _enable_framewise_encoding(self):
r"""
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
@@ -860,7 +881,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -885,7 +906,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
@@ -915,7 +936,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -1013,7 +1034,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1097,7 +1118,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor | torch.Tensor:
) -> Union[torch.Tensor, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
@@ -18,7 +18,7 @@
# - GitHub: https://github.com/Wan-Video/Wan2.1
# - arXiv: https://arxiv.org/abs/2503.20314
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -31,7 +31,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -58,9 +58,9 @@ class QwenImageCausalConv3d(nn.Conv3d):
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int],
stride: int | tuple[int, int, int] = 1,
padding: int | tuple[int, int, int] = 0,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
@@ -663,7 +663,7 @@ class QwenImageDecoder3d(nn.Module):
return x
class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -679,13 +679,13 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: list[float] = [],
temperal_downsample: list[bool] = [False, True, True],
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: list[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
latents_std: list[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
) -> None:
# fmt: on
super().__init__()
@@ -763,6 +763,27 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
def _count_conv3d(model):
count = 0
@@ -806,7 +827,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
@@ -856,7 +877,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return DecoderOutput(sample=out)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
@@ -962,7 +983,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1031,7 +1052,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module):
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -145,10 +145,10 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
tuple of downsample block types.
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
tuple of block output channels.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: tuple[int] = (64,),
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
@@ -204,7 +204,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> dict[str, AttentionProcessor]:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -213,7 +213,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
@@ -228,7 +228,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -278,7 +278,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -308,7 +308,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
z: torch.Tensor,
num_frames: int,
return_dict: bool = True,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -339,7 +339,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,7 +25,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -149,9 +149,9 @@ class WanCausalConv3d(nn.Conv3d):
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int],
stride: int | tuple[int, int, int] = 1,
padding: int | tuple[int, int, int] = 0,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
return x
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.downsampler(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.upsampler(x, feat_cache, feat_idx)
else:
x = self.upsampler(x)
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -951,7 +951,7 @@ def unpatchify(x, patch_size):
return x
class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -961,9 +961,6 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""
_supports_gradient_checkpointing = False
# keys toignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
@register_to_config
def __init__(
@@ -971,12 +968,12 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: list[float] = [],
temperal_downsample: list[bool] = [False, True, True],
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: list[float] = [
latents_mean: List[float] = [
-0.7571,
-0.7089,
-0.9113,
@@ -994,7 +991,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
0.2503,
-0.2921,
],
latents_std: list[float] = [
latents_std: List[float] = [
2.8184,
1.4541,
2.3275,
@@ -1113,6 +1110,27 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
@@ -1153,7 +1171,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
@@ -1209,7 +1227,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
return DecoderOutput(sample=out)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
@@ -1315,7 +1333,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1337,18 +1355,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width
if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = self.tile_sample_min_height - tile_sample_stride_height
blend_width = self.tile_sample_min_width - tile_sample_stride_width
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1362,9 +1371,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1380,15 +1387,11 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)
dec = torch.clamp(dec, min=-1.0, max=1.0)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@@ -1399,7 +1402,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | torch.Tensor:
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -25,7 +25,6 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin
class Snake1d(nn.Module):
@@ -292,7 +291,7 @@ class OobleckDecoder(nn.Module):
return hidden_state
class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderOobleck(ModelMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
@@ -303,9 +302,9 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
Parameters:
encoder_hidden_size (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the encoder.
downsampling_ratios (`list[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
channel_multiples (`list[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
Multiples used to determine the hidden sizes of the hidden layers.
decoder_channels (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the decoder.
@@ -357,10 +356,24 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
self.use_slicing = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderOobleckOutput | tuple[OobleckDiagonalGaussianDistribution]:
) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -386,7 +399,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
return AutoencoderOobleckOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> OobleckDecoderOutput | torch.Tensor:
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
dec = self.decoder(z)
if not return_dict:
@@ -397,7 +410,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> OobleckDecoderOutput | torch.FloatTensor:
) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
@@ -429,7 +442,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> OobleckDecoderOutput | torch.Tensor:
) -> Union[OobleckDecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -14,7 +14,7 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple, Union
import torch
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
@dataclass
@@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput):
latents: torch.Tensor
class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderTiny(ModelMixin, ConfigMixin):
r"""
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
@@ -50,11 +50,11 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
Parameters:
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
encoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
tuple of integers representing the number of output channels for each encoder block. The length of the
encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
Tuple of integers representing the number of output channels for each encoder block. The length of the
tuple should be equal to the number of encoder blocks.
decoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
tuple of integers representing the number of output channels for each decoder block. The length of the
decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
Tuple of integers representing the number of output channels for each decoder block. The length of the
tuple should be equal to the number of decoder blocks.
act_fn (`str`, *optional*, defaults to `"relu"`):
Activation function to be used throughout the model.
@@ -64,12 +64,12 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
upsampling_scaling_factor (`int`, *optional*, defaults to 2):
Scaling factor for upsampling in the decoder. It determines the size of the output image during the
upsampling process.
num_encoder_blocks (`tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
number of encoder blocks.
num_decoder_blocks (`tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
number of decoder blocks.
latent_magnitude (`float`, *optional*, defaults to 3.0):
@@ -99,14 +99,14 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
encoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
decoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
act_fn: str = "relu",
upsample_fn: str = "nearest",
latent_channels: int = 4,
upsampling_scaling_factor: int = 2,
num_encoder_blocks: tuple[int, ...] = (1, 3, 3, 3),
num_decoder_blocks: tuple[int, ...] = (3, 3, 3, 1),
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
latent_magnitude: int = 3,
latent_shift: float = 0.5,
force_upcast: bool = False,
@@ -162,6 +162,35 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
@@ -258,7 +287,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
return out
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderTinyOutput | tuple[torch.Tensor]:
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
@@ -275,7 +304,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def decode(
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [
self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
@@ -293,7 +322,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
sample: torch.Tensor,
return_dict: bool = True,
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -32,7 +32,7 @@ from ..attention_processor import (
)
from ..modeling_utils import ModelMixin
from ..unets.unet_2d import UNet2DModel
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"
class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
r"""
The consistency decoder used with DALL-E 3.
@@ -77,9 +77,9 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
latent_channels: int = 4,
sample_size: int = 32,
encoder_act_fn: str = "silu",
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
encoder_double_z: bool = True,
encoder_down_block_types: tuple[str, ...] = (
encoder_down_block_types: Tuple[str, ...] = (
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
@@ -90,8 +90,8 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
encoder_norm_num_groups: int = 32,
encoder_out_channels: int = 4,
decoder_add_attention: bool = False,
decoder_block_out_channels: tuple[int, ...] = (320, 640, 1024, 1024),
decoder_down_block_types: tuple[str, ...] = (
decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
decoder_down_block_types: Tuple[str, ...] = (
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
@@ -106,7 +106,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
decoder_out_channels: int = 6,
decoder_resnet_time_scale_shift: str = "scale_shift",
decoder_time_embedding_type: str = "learned",
decoder_up_block_types: tuple[str, ...] = (
decoder_up_block_types: Tuple[str, ...] = (
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
@@ -167,9 +167,42 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> dict[str, AttentionProcessor]:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -178,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
@@ -193,7 +226,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -246,7 +279,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> ConsistencyDecoderVAEOutput | tuple[DiagonalGaussianDistribution]:
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -285,7 +318,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
num_inference_steps: int = 2,
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
"""
Decodes the input latent vector `z` using the consistency decoder VAE model.
@@ -296,7 +329,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
num_inference_steps (int): The number of inference steps. Default is 2.
Returns:
Union[DecoderOutput, tuple[torch.Tensor]]: The decoded output.
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
"""
z = (z * self.config.scaling_factor - self.means) / self.stds
@@ -339,7 +372,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput | tuple:
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -400,7 +433,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> DecoderOutput | tuple[torch.Tensor]:
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
+27 -59
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple
import numpy as np
import torch
@@ -66,10 +66,10 @@ class Encoder(nn.Module):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
@@ -85,8 +85,8 @@ class Encoder(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: tuple[int, ...] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -187,9 +187,9 @@ class Decoder(nn.Module):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
@@ -205,8 +205,8 @@ class Decoder(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: tuple[int, ...] = (64,),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -286,9 +286,11 @@ class Decoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -296,6 +298,7 @@ class Decoder(nn.Module):
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -402,9 +405,9 @@ class MaskConditionDecoder(nn.Module):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
@@ -420,8 +423,8 @@ class MaskConditionDecoder(nn.Module):
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: tuple[int, ...] = (64,),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
@@ -633,7 +636,7 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple]:
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
@@ -667,7 +670,7 @@ class VectorQuantizer(nn.Module):
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices: torch.LongTensor, shape: tuple[int, ...]) -> torch.Tensor:
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
@@ -728,7 +731,7 @@ class DiagonalGaussianDistribution(object):
dim=[1, 2, 3],
)
def nll(self, sample: torch.Tensor, dims: tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
@@ -761,10 +764,10 @@ class EncoderTiny(nn.Module):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`tuple[int, ...]`):
num_blocks (`Tuple[int, ...]`):
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
use.
block_out_channels (`tuple[int, ...]`):
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
@@ -774,8 +777,8 @@ class EncoderTiny(nn.Module):
self,
in_channels: int,
out_channels: int,
num_blocks: tuple[int, ...],
block_out_channels: tuple[int, ...],
num_blocks: Tuple[int, ...],
block_out_channels: Tuple[int, ...],
act_fn: str,
):
super().__init__()
@@ -827,10 +830,10 @@ class DecoderTiny(nn.Module):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`tuple[int, ...]`):
num_blocks (`Tuple[int, ...]`):
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
use.
block_out_channels (`tuple[int, ...]`):
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
upsampling_scaling_factor (`int`):
The scaling factor to use for upsampling.
@@ -842,8 +845,8 @@ class DecoderTiny(nn.Module):
self,
in_channels: int,
out_channels: int,
num_blocks: tuple[int, ...],
block_out_channels: tuple[int, ...],
num_blocks: Tuple[int, ...],
block_out_channels: Tuple[int, ...],
upsampling_scaling_factor: int,
act_fn: str,
upsample_fn: str,
@@ -891,38 +894,3 @@ class DecoderTiny(nn.Module):
# scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1)
class AutoencoderMixin:
def enable_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
if not hasattr(self, "use_tiling"):
raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.")
self.use_tiling = True
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
if not hasattr(self, "use_slicing"):
raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.")
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

Some files were not shown because too many files have changed in this diff Show More