Compare commits
40 Commits
vace-fix
...
modular-docs
| Author | SHA1 | Date | |
|---|---|---|---|
| 6d9c5a8d3a | |||
| a9cb08af39 | |||
| d6f66f4946 | |||
| 9f669e7b5d | |||
| 8ac17cd2cb | |||
| e4393fa613 | |||
| b3e9dfced7 | |||
| 58f3771545 | |||
| 6198f8a12b | |||
| dcfb18a2d3 | |||
| ac5a1e28fc | |||
| 325a95051b | |||
| 1ec28a2c77 | |||
| de6173c683 | |||
| 8f80dda193 | |||
| cdbf0ad883 | |||
| 5e8415a311 | |||
| 051c8a1c0f | |||
| d54622c267 | |||
| df8dd77817 | |||
| 9f3c0fdcd8 | |||
| 84e16575e4 | |||
| 55d49d4379 | |||
| 40528e9ae7 | |||
| dc622a95d0 | |||
| ecfbc8f952 | |||
| df0e2a4f2c | |||
| 303efd2b8d | |||
| 5afbcce176 | |||
| 6d1a648602 | |||
| 250f5cb53d | |||
| dc6bd1511a | |||
| 500b9cf184 | |||
| d34b18c783 | |||
| 7536f647e4 | |||
| a138d71ec1 | |||
| bc4039886d | |||
| 9c3b58dcf1 | |||
| 74b5fed434 | |||
| 85eb505672 |
@@ -323,6 +323,8 @@
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/transformer_bria_fibo
|
||||
title: BriaFiboTransformer2DModel
|
||||
- local: api/models/bria_transformer
|
||||
title: BriaTransformer2DModel
|
||||
- local: api/models/chroma_transformer
|
||||
@@ -347,6 +349,8 @@
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/hunyuanimage_transformer_2d
|
||||
title: HunyuanImageTransformer2DModel
|
||||
- local: api/models/hunyuan_video_transformer_3d
|
||||
title: HunyuanVideoTransformer3DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
@@ -369,6 +373,8 @@
|
||||
title: QwenImageTransformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sana_video_transformer3d
|
||||
title: SanaVideoTransformer3DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/skyreels_v2_transformer_3d
|
||||
@@ -411,6 +417,10 @@
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoderkl_cosmos
|
||||
title: AutoencoderKLCosmos
|
||||
- local: api/models/autoencoder_kl_hunyuanimage
|
||||
title: AutoencoderKLHunyuanImage
|
||||
- local: api/models/autoencoder_kl_hunyuanimage_refiner
|
||||
title: AutoencoderKLHunyuanImageRefiner
|
||||
- local: api/models/autoencoder_kl_hunyuan_video
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
@@ -463,6 +473,8 @@
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/bria_fibo
|
||||
title: Bria Fibo
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogview3
|
||||
@@ -553,6 +565,8 @@
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/sana_video
|
||||
title: Sana Video
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
@@ -620,10 +634,14 @@
|
||||
title: ConsisID
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hunyuanimage21
|
||||
title: HunyuanImage2.1
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/kandinsky5_video
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLHunyuanImage
|
||||
|
||||
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanImage
|
||||
|
||||
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanImage
|
||||
|
||||
[[autodoc]] AutoencoderKLHunyuanImage
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
[[autodoc]] AutoencoderKLHunyuanImageRefiner
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# ChromaTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
|
||||
|
||||
## ChromaTransformer2DModel
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HunyuanImageTransformer2DModel
|
||||
|
||||
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanImageTransformer2DModel
|
||||
|
||||
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanImageTransformer2DModel
|
||||
|
||||
[[autodoc]] HunyuanImageTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,36 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# SanaVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import SanaVideoTransformer3DModel
|
||||
import torch
|
||||
|
||||
transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## SanaVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] SanaVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BriaFiboTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
|
||||
|
||||
## BriaFiboTransformer2DModel
|
||||
|
||||
[[autodoc]] BriaFiboTransformer2DModel
|
||||
@@ -0,0 +1,45 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Bria Fibo
|
||||
|
||||
Text-to-image models have mastered imagination - but not control. FIBO changes that.
|
||||
|
||||
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
|
||||
|
||||
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
|
||||
|
||||
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
|
||||
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
|
||||
|
||||
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
|
||||
|
||||
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
|
||||
Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
|
||||
## BriaPipeline
|
||||
|
||||
[[autodoc]] BriaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Chroma is a text to image generation model based on Flux.
|
||||
|
||||
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
|
||||
Original model checkpoints for Chroma can be found here:
|
||||
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
|
||||
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
|
||||
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
|
||||
|
||||
> [!TIP]
|
||||
> Chroma can use all the same optimizations as Flux.
|
||||
|
||||
## Inference
|
||||
|
||||
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaPipeline
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = [
|
||||
@@ -63,10 +64,10 @@ Then run the following example
|
||||
import torch
|
||||
from diffusers import ChromaTransformer2DModel, ChromaPipeline
|
||||
|
||||
model_id = "lodestones/Chroma"
|
||||
model_id = "lodestones/Chroma1-HD"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# HunyuanImage2.1
|
||||
|
||||
|
||||
HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
|
||||
|
||||
HunyuanImage-2.1 comes in the following variants:
|
||||
|
||||
| model type | model id |
|
||||
|:----------:|:--------:|
|
||||
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
|
||||
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
|
||||
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
|
||||
|
||||
> [!TIP]
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## HunyuanImage-2.1
|
||||
|
||||
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HunyuanImagePipeline
|
||||
|
||||
pipe = HunyuanImagePipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
```
|
||||
|
||||
You can inspect the `guider` object:
|
||||
|
||||
```py
|
||||
>>> pipe.guider
|
||||
AdaptiveProjectedMixGuidance {
|
||||
"_class_name": "AdaptiveProjectedMixGuidance",
|
||||
"_diffusers_version": "0.36.0.dev0",
|
||||
"adaptive_projected_guidance_momentum": -0.5,
|
||||
"adaptive_projected_guidance_rescale": 10.0,
|
||||
"adaptive_projected_guidance_scale": 10.0,
|
||||
"adaptive_projected_guidance_start_step": 5,
|
||||
"enabled": true,
|
||||
"eta": 0.0,
|
||||
"guidance_rescale": 0.0,
|
||||
"guidance_scale": 3.5,
|
||||
"start": 0.0,
|
||||
"stop": 1.0,
|
||||
"use_original_formulation": false
|
||||
}
|
||||
|
||||
State:
|
||||
step: None
|
||||
num_inference_steps: None
|
||||
timestep: None
|
||||
count_prepared: 0
|
||||
enabled: True
|
||||
num_conditions: 2
|
||||
momentum_buffer: None
|
||||
is_apg_enabled: False
|
||||
is_cfg_enabled: True
|
||||
```
|
||||
|
||||
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanImagePipeline
|
||||
|
||||
pipe = HunyuanImagePipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
# Update the guider configuration
|
||||
pipe.guider = pipe.guider.new(guidance_scale=5.0)
|
||||
|
||||
prompt = (
|
||||
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
|
||||
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
|
||||
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=50,
|
||||
height=2048,
|
||||
width=2048,
|
||||
).images[0]
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
|
||||
## HunyuanImage-2.1-Distilled
|
||||
|
||||
use `distilled_guidance_scale` with the guidance-distilled checkpoint,
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanImagePipeline
|
||||
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = (
|
||||
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
|
||||
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
|
||||
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
|
||||
)
|
||||
|
||||
out = pipe(
|
||||
prompt,
|
||||
num_inference_steps=8,
|
||||
distilled_guidance_scale=3.25,
|
||||
height=2048,
|
||||
width=2048,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
|
||||
```
|
||||
|
||||
|
||||
## HunyuanImagePipeline
|
||||
|
||||
[[autodoc]] HunyuanImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HunyuanImageRefinerPipeline
|
||||
|
||||
[[autodoc]] HunyuanImageRefinerPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## HunyuanImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
|
||||
@@ -0,0 +1,149 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kandinsky 5.0 Video
|
||||
|
||||
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
|
||||
|
||||
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
|
||||
|
||||
The model introduces several key innovations:
|
||||
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
|
||||
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
|
||||
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
|
||||
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
|
||||
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
|
||||
|
||||
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
|
||||
|
||||
> [!TIP]
|
||||
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
|
||||
|
||||
## Available Models
|
||||
|
||||
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
|
||||
|
||||
| model_id | Description | Use Cases |
|
||||
|------------|-------------|-----------|
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
|
||||
|
||||
All models are available in 5-second and 10-second video generation versions.
|
||||
|
||||
## Kandinsky5T2VPipeline
|
||||
|
||||
[[autodoc]] Kandinsky5T2VPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Text-to-Video Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Kandinsky5T2VPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Load the pipeline
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
|
||||
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
# Generate video
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen."
|
||||
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=512,
|
||||
width=768,
|
||||
num_frames=121, # ~5 seconds at 24fps
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### 10 second Models
|
||||
**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
|
||||
|
||||
```python
|
||||
pipe = Kandinsky5T2VPipeline.from_pretrained(
|
||||
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.transformer.set_attention_backend(
|
||||
"flex"
|
||||
) # <--- Sett attention bakend to Flex
|
||||
pipe.transformer.compile(
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
dynamic=True
|
||||
) # <--- Compile with max-autotune-no-cudagraphs
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen."
|
||||
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=512,
|
||||
width=768,
|
||||
num_frames=241,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### Diffusion Distilled model
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
|
||||
|
||||
```python
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
|
||||
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
output = pipe(
|
||||
prompt="A beautiful sunset over mountains",
|
||||
num_inference_steps=16, # <--- Model is distilled in 16 steps
|
||||
guidance_scale=1.0, # <--- no CFG
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
|
||||
## Citation
|
||||
```bibtex
|
||||
@misc{kandinsky2025,
|
||||
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
|
||||
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
|
||||
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
|
||||
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
|
||||
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
|
||||
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
|
||||
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
|
||||
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
|
||||
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
|
||||
year = 2025
|
||||
}
|
||||
```
|
||||
@@ -24,9 +24,6 @@ The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaVideoPipeline
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
|
||||
|
||||
This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
model_score = 30
|
||||
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_prompt = f" motion score: {model_score}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=50
|
||||
).frames[0]
|
||||
export_to_video(output, "sana-video-output.mp4", fps=16)
|
||||
```
|
||||
|
||||
## SanaVideoPipeline
|
||||
|
||||
[[autodoc]] SanaVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# LoopSequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
|
||||
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
|
||||
|
||||
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
|
||||
|
||||
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
|
||||
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
|
||||
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
|
||||
- `__call__` method defines the loop structure and iteration logic.
|
||||
|
||||
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
|
||||
|
||||
```py
|
||||
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
|
||||
```
|
||||
```
|
||||
|
||||
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
|
||||
|
||||
Use `InputParam` to define `intermediate_inputs`.
|
||||
|
||||
```py
|
||||
user_intermediate_inputs = [
|
||||
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
|
||||
Use `OutputParam` to define `intermediate_outputs`.
|
||||
|
||||
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
|
||||
|
||||
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
|
||||
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
|
||||
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
|
||||
2. Implement the computation logic on the `inputs`.
|
||||
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
|
||||
4. Return the components and state which becomes available to the next block.
|
||||
|
||||
@@ -76,7 +66,7 @@ def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs and intermediate_inputs
|
||||
# block_state contains all your inputs
|
||||
# Access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
@@ -112,4 +102,4 @@ def __call__(self, components, state):
|
||||
unet = components.unet
|
||||
vae = components.vae
|
||||
scheduler = components.scheduler
|
||||
```
|
||||
```
|
||||
|
||||
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
|
||||
components = ComponentManager()
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
|
||||
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# SequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
|
||||
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
|
||||
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
|
||||
|
||||
<hfoptions id="sequential">
|
||||
<hfoption id="InputBlock">
|
||||
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
|
||||
```py
|
||||
print(blocks)
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
|
||||
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
|
||||
| attention family | main feature |
|
||||
|---|---|
|
||||
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
|
||||
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
|
||||
| SageAttention | quantizes attention to int8 |
|
||||
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
|
||||
| xFormers | memory-efficient attention with support for various attention kernels |
|
||||
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
|
||||
- you can either provide your own folder as `--train_data_dir`
|
||||
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
|
||||
|
||||
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
|
||||
|
||||
Below, we explain both in more detail.
|
||||
|
||||
#### Provide the dataset as a folder
|
||||
|
||||
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
|
||||
"""
|
||||
if tensor.ndim == 2:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
channels = tensor.shape[0]
|
||||
if channels == 3:
|
||||
return tensor
|
||||
if channels == 1:
|
||||
return tensor.repeat(3, 1, 1)
|
||||
if channels == 2:
|
||||
return torch.cat([tensor, tensor[:1]], dim=0)
|
||||
if channels > 3:
|
||||
return tensor[:3]
|
||||
raise ValueError(f"Unsupported number of channels: {channels}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -260,6 +278,11 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preserve_input_precision",
|
||||
action="store_true",
|
||||
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -453,19 +476,41 @@ def main(args):
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets and DataLoaders creation.
|
||||
spatial_augmentations = [
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
]
|
||||
|
||||
augmentations = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
spatial_augmentations
|
||||
+ [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
precision_augmentations = transforms.Compose(
|
||||
[
|
||||
transforms.PILToTensor(),
|
||||
transforms.Lambda(_ensure_three_channels),
|
||||
transforms.ConvertImageDtype(torch.float32),
|
||||
]
|
||||
+ spatial_augmentations
|
||||
+ [transforms.Normalize([0.5], [0.5])]
|
||||
)
|
||||
|
||||
def transform_images(examples):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
processed = []
|
||||
for image in examples["image"]:
|
||||
if not args.preserve_input_precision:
|
||||
processed.append(augmentations(image.convert("RGB")))
|
||||
else:
|
||||
precise_image = image
|
||||
if precise_image.mode == "P":
|
||||
precise_image = precise_image.convert("RGB")
|
||||
processed.append(precision_augmentations(precise_image))
|
||||
return {"input": processed}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,324 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaVideoPipeline,
|
||||
SanaVideoTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
|
||||
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
|
||||
|
||||
|
||||
def main(args):
|
||||
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
|
||||
|
||||
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
|
||||
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
|
||||
snapshot_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
file_path = hf_hub_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
else:
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
|
||||
all_state_dict = torch.load(file_path, weights_only=True)
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 8.0
|
||||
|
||||
# model config
|
||||
layer_num = 20
|
||||
# Positional embedding interpolation scale.
|
||||
qk_norm = True
|
||||
|
||||
# sample size
|
||||
if args.video_size == 480:
|
||||
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
|
||||
patch_size = (1, 2, 2)
|
||||
elif args.video_size == 720:
|
||||
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
|
||||
patch_size = (1, 1, 1)
|
||||
else:
|
||||
raise ValueError(f"Video size {args.video_size} is not supported.")
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"blocks.{depth}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Linear Attention is all you need 🤘
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.t_conv.weight"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# Final block.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer_kwargs = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 20,
|
||||
"attention_head_dim": 112,
|
||||
"num_layers": 20,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 3.0,
|
||||
"attention_bias": False,
|
||||
"sample_size": sample_size,
|
||||
"patch_size": patch_size,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 1024,
|
||||
}
|
||||
|
||||
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
|
||||
|
||||
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
|
||||
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
transformer = transformer.to(weight_dtype)
|
||||
|
||||
if not args.save_full_pipeline:
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
elif args.scheduler_type == "uni-pc":
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction",
|
||||
use_flow_sigmas=True,
|
||||
num_train_timesteps=1000,
|
||||
flow_shift=flow_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaVideoPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_size",
|
||||
default=480,
|
||||
type=int,
|
||||
choices=[480, 720],
|
||||
required=False,
|
||||
help="Video size of pretrained model, 480 or 720.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="SanaVideo",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaVideo",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
@@ -149,7 +149,9 @@ else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"AdaptiveProjectedMixGuidance",
|
||||
"AutoGuidance",
|
||||
"BaseGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"FrequencyDecoupledGuidance",
|
||||
@@ -184,6 +186,8 @@ else:
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLHunyuanImage",
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
@@ -194,6 +198,7 @@ else:
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"BriaFiboTransformer2DModel",
|
||||
"BriaTransformer2DModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
@@ -216,6 +221,7 @@ else:
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanImageTransformer2DModel",
|
||||
"HunyuanVideoFramepackTransformer3DModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
@@ -240,6 +246,7 @@ else:
|
||||
"QwenImageTransformer2DModel",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SanaVideoTransformer3DModel",
|
||||
"SD3ControlNetModel",
|
||||
"SD3MultiControlNetModel",
|
||||
"SD3Transformer2DModel",
|
||||
@@ -425,6 +432,7 @@ else:
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"BriaFiboPipeline",
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaPipeline",
|
||||
@@ -462,6 +470,8 @@ else:
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanImagePipeline",
|
||||
"HunyuanImageRefinerPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
@@ -535,6 +545,7 @@ else:
|
||||
"SanaPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -849,7 +860,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
AdaptiveProjectedMixGuidance,
|
||||
AutoGuidance,
|
||||
BaseGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
@@ -880,6 +893,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -890,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
CacheMixin,
|
||||
ChromaTransformer2DModel,
|
||||
@@ -912,6 +928,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
@@ -936,6 +953,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageTransformer2DModel,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SD3Transformer2DModel,
|
||||
@@ -1091,6 +1109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
BriaFiboPipeline,
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaPipeline,
|
||||
@@ -1128,6 +1147,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanImagePipeline,
|
||||
HunyuanImageRefinerPipeline,
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
@@ -1201,6 +1222,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaVideoPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
|
||||
@@ -14,28 +14,18 @@
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ..utils import is_torch_available
|
||||
from ..utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .guider_utils import BaseGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
||||
|
||||
GuiderType = Union[
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
TangentialClassifierFreeGuidance,
|
||||
]
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
@@ -76,19 +77,14 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
@@ -152,6 +148,44 @@ class MomentumBuffer:
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
||||
"""
|
||||
if isinstance(self.running_average, torch.Tensor):
|
||||
shape = tuple(self.running_average.shape)
|
||||
|
||||
# Calculate statistics
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": self.running_average.mean().item(),
|
||||
"std": self.running_average.std().item(),
|
||||
"min": self.running_average.min().item(),
|
||||
"max": self.running_average.max().item(),
|
||||
}
|
||||
|
||||
# Get a slice (max 3 elements per dimension)
|
||||
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
||||
sliced_data = self.running_average[slice_indices]
|
||||
|
||||
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
||||
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
||||
if len(slice_str) > 200: # Truncate if too long
|
||||
slice_str = slice_str[:200] + "..."
|
||||
|
||||
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
||||
|
||||
return (
|
||||
f"MomentumBuffer(\n"
|
||||
f" momentum={self.momentum},\n"
|
||||
f" shape={shape},\n"
|
||||
f" stats=[{stats_str}],\n"
|
||||
f" slice={slice_str}\n"
|
||||
f")"
|
||||
)
|
||||
else:
|
||||
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AdaptiveProjectedMixGuidance(BaseGuidance):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
|
||||
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
|
||||
improve image quality and fix
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
|
||||
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
|
||||
Steps are Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
|
||||
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
|
||||
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
|
||||
used, and momentum buffer is updated).
|
||||
enabled (`bool`, defaults to `True`):
|
||||
Whether this guidance is enabled.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 3.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
adaptive_projected_guidance_scale: float = 10.0,
|
||||
adaptive_projected_guidance_momentum: float = -0.5,
|
||||
adaptive_projected_guidance_rescale: float = 10.0,
|
||||
eta: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
adaptive_projected_guidance_start_step: int = 5,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
# no guidance
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
|
||||
# CFG + update momentum buffer
|
||||
elif not self._is_apg_enabled():
|
||||
if self.momentum_buffer is not None:
|
||||
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
|
||||
# CFG + update momentum buffer
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
# APG
|
||||
elif self._is_apg_enabled():
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.adaptive_projected_guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_apg_enabled() or self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_apg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
return False
|
||||
|
||||
is_within_range = False
|
||||
if self._step is not None:
|
||||
is_within_range = self._step > self.adaptive_projected_guidance_start_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def get_state(self):
|
||||
state = super().get_state()
|
||||
state["momentum_buffer"] = self.momentum_buffer
|
||||
state["is_apg_enabled"] = self._is_apg_enabled()
|
||||
state["is_cfg_enabled"] = self._is_cfg_enabled()
|
||||
return state
|
||||
|
||||
|
||||
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
||||
"""
|
||||
if isinstance(self.running_average, torch.Tensor):
|
||||
shape = tuple(self.running_average.shape)
|
||||
|
||||
# Calculate statistics
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": self.running_average.mean().item(),
|
||||
"std": self.running_average.std().item(),
|
||||
"min": self.running_average.min().item(),
|
||||
"max": self.running_average.max().item(),
|
||||
}
|
||||
|
||||
# Get a slice (max 3 elements per dimension)
|
||||
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
||||
sliced_data = self.running_average[slice_indices]
|
||||
|
||||
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
||||
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
||||
if len(slice_str) > 200: # Truncate if too long
|
||||
slice_str = slice_str[:200] + "..."
|
||||
|
||||
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
||||
|
||||
return (
|
||||
f"MomentumBuffer(\n"
|
||||
f" momentum={self.momentum},\n"
|
||||
f" shape={shape},\n"
|
||||
f" stats=[{stats_str}],\n"
|
||||
f" slice={slice_str}\n"
|
||||
f")"
|
||||
)
|
||||
else:
|
||||
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
||||
|
||||
|
||||
def update_momentum_buffer(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
if momentum_buffer is not None:
|
||||
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
|
||||
diff = momentum_buffer.running_average
|
||||
else:
|
||||
diff = pred_cond - pred_uncond
|
||||
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale * normalized_update
|
||||
|
||||
return pred
|
||||
@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.auto_guidance_layers = auto_guidance_layers
|
||||
@@ -132,16 +133,11 @@ class AutoGuidance(BaseGuidance):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,43 +27,50 @@ if TYPE_CHECKING:
|
||||
|
||||
class ClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
Implements Classifier-Free Guidance (CFG) for diffusion models.
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
|
||||
proposes scaling and shifting the conditional distribution based on the difference between conditional and
|
||||
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
Reference: https://huggingface.co/papers/2207.12598
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
|
||||
unconditional data, then combining predictions during inference. This allows trading off between quality (high
|
||||
guidance) and diversity (low guidance).
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
**Two CFG Formulations:**
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
1. **Original formulation** (from paper):
|
||||
```
|
||||
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
|
||||
```
|
||||
Moves conditional predictions further from unconditional ones.
|
||||
|
||||
2. **Diffusers-native formulation** (default, from Imagen paper):
|
||||
```
|
||||
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
|
||||
```
|
||||
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
|
||||
quality", "watermarks"). Equivalent in theory but more intuitive.
|
||||
|
||||
Use `use_original_formulation=True` to switch to the original formulation.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
|
||||
may reduce quality. Typical range: 1.0-20.0.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
|
||||
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
|
||||
to 1.0 (full rescaling).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
|
||||
diffusers-native formulation from the Imagen paper.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
|
||||
steps.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
|
||||
steps.
|
||||
enabled (`bool`, defaults to `True`):
|
||||
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
@@ -76,23 +83,19 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
# YiYi Notes: add default behavior for self._enabled == False
|
||||
if not self._enabled:
|
||||
pred = pred_cond
|
||||
|
||||
elif self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
|
||||
@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
enabled: bool = True,
|
||||
):
|
||||
if not _CAN_USE_KORNIA:
|
||||
raise ImportError(
|
||||
@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
||||
min_start = start if isinstance(start, float) else min(start)
|
||||
max_stop = stop if isinstance(stop, float) else max(stop)
|
||||
super().__init__(min_start, max_stop)
|
||||
super().__init__(min_start, max_stop, enabled)
|
||||
|
||||
self.guidance_scales = guidance_scales
|
||||
self.levels = len(guidance_scales)
|
||||
@@ -217,16 +218,11 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -40,7 +40,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
_input_predictions = None
|
||||
_identifier_key = "__guidance_identifier__"
|
||||
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
|
||||
logger.warning(
|
||||
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
|
||||
)
|
||||
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._step: int = None
|
||||
@@ -48,7 +52,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = True
|
||||
self._enabled = enabled
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
@@ -60,6 +64,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def new(self, **kwargs):
|
||||
"""
|
||||
Creates a copy of this guider instance, optionally with modified configuration parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
|
||||
returns an exact copy with the same configuration.
|
||||
|
||||
Returns:
|
||||
A new guider instance with the same (or updated) configuration.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Create a CFG guider
|
||||
guider = ClassifierFreeGuidance(guidance_scale=3.5)
|
||||
|
||||
# Create an exact copy
|
||||
same_guider = guider.new()
|
||||
|
||||
# Create a copy with different start step, keeping other config the same
|
||||
new_guider = guider.new(guidance_scale=5)
|
||||
```
|
||||
"""
|
||||
return self.__class__.from_config(self.config, **kwargs)
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
@@ -72,42 +101,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
||||
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
||||
the values of the provided keyword arguments to this method.
|
||||
|
||||
Args:
|
||||
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation.
|
||||
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
||||
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
||||
conditional data identifier and the second element must be the unconditional data identifier or None.
|
||||
|
||||
Example:
|
||||
```
|
||||
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
||||
|
||||
BaseGuidance.set_input_fields(
|
||||
latents="latents",
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
)
|
||||
```
|
||||
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
|
||||
the __repr__ method. Returns:
|
||||
`Dict[str, Any]`: A dictionary containing the current state variables including:
|
||||
- step: Current inference step
|
||||
- num_inference_steps: Total number of inference steps
|
||||
- timestep: Current timestep tensor
|
||||
- count_prepared: Number of times prepare_models has been called
|
||||
- enabled: Whether the guidance is enabled
|
||||
- num_conditions: Number of conditions
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
is_string = isinstance(value, str)
|
||||
is_tuple_of_str_with_len_2 = (
|
||||
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
||||
)
|
||||
if not (is_string or is_tuple_of_str_with_len_2):
|
||||
raise ValueError(
|
||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||
)
|
||||
self._input_fields = kwargs
|
||||
state = {
|
||||
"step": self._step,
|
||||
"num_inference_steps": self._num_inference_steps,
|
||||
"timestep": self._timestep,
|
||||
"count_prepared": self._count_prepared,
|
||||
"enabled": self._enabled,
|
||||
"num_conditions": self.num_conditions,
|
||||
}
|
||||
return state
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the guidance object including both config and current state.
|
||||
"""
|
||||
# Get ConfigMixin's __repr__
|
||||
str_repr = super().__repr__()
|
||||
|
||||
# Get current state
|
||||
state = self.get_state()
|
||||
|
||||
# Format each state variable on its own line with indentation
|
||||
state_lines = []
|
||||
for k, v in state.items():
|
||||
# Convert value to string and handle multi-line values
|
||||
v_str = str(v)
|
||||
if "\n" in v_str:
|
||||
# For multi-line values (like MomentumBuffer), indent subsequent lines
|
||||
v_lines = v_str.split("\n")
|
||||
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
|
||||
state_lines.append(f" {k}: {v_str}")
|
||||
|
||||
state_str = "\n".join(state_lines)
|
||||
|
||||
return f"{str_repr}\nState:\n{state_str}"
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
@@ -155,8 +194,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
@@ -182,21 +220,16 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
if input_fields is None:
|
||||
raise ValueError(
|
||||
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
||||
)
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
for key, value in data.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
if isinstance(value, torch.Tensor):
|
||||
data_batch[key] = value
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
data_batch[key] = value[tuple_index]
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
raise ValueError(f"Invalid value type: {type(value)}")
|
||||
except ValueError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
||||
@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -186,8 +182,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -182,8 +178,8 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.seg_guidance_scale = seg_guidance_scale
|
||||
@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -171,8 +167,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -58,23 +58,19 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
|
||||
from ..models.attention_processor import AttnProcessor2_0
|
||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
||||
from ..models.transformers.transformer_flux import FluxAttnProcessor
|
||||
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
|
||||
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
||||
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||
|
||||
@@ -149,6 +150,14 @@ def _register_attention_processors_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanImageAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=HunyuanImageAttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_hunyuanimage import (
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanImage2.1
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanImageSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
|
||||
# not sure what this is yet.
|
||||
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
@@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook):
|
||||
|
||||
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
||||
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
|
||||
logger.warning_once(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
return x
|
||||
else:
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
class ContextParallelGatherHook(ModelHook):
|
||||
|
||||
@@ -1045,16 +1045,39 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
Args:
|
||||
image (`Union[np.ndarray, torch.Tensor]`):
|
||||
The RGB-like depth image to convert.
|
||||
|
||||
Returns:
|
||||
`Union[np.ndarray, torch.Tensor]`:
|
||||
The corresponding depth map.
|
||||
"""
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
# 1. Cast the tensor to a larger integer type (e.g., int32)
|
||||
# to safely perform the multiplication by 256.
|
||||
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
|
||||
# 3. Cast the final result to the desired depth map type (uint16) if needed
|
||||
# before returning, though leaving it as int32/int64 is often safer
|
||||
# for return value from a library function.
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.to(torch.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# You may want to cast the final result to uint16, but casting to a
|
||||
# larger int type (like int32) is sufficient to fix the overflow.
|
||||
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.to(original_dtype)
|
||||
|
||||
elif isinstance(image, np.ndarray):
|
||||
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.astype(np.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.astype(original_dtype)
|
||||
else:
|
||||
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
|
||||
@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
"time_projection.1.diff_b"
|
||||
)
|
||||
|
||||
if any("head.head" in k for k in state_dict):
|
||||
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
|
||||
f"head.head.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
|
||||
if any("head.head" in k for k in original_state_dict):
|
||||
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
|
||||
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
|
||||
f"head.head.{lora_down_key}.weight"
|
||||
)
|
||||
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
|
||||
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
|
||||
f"head.head.{lora_up_key}.weight"
|
||||
)
|
||||
if "head.head.diff_b" in original_state_dict:
|
||||
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
|
||||
|
||||
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
|
||||
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
|
||||
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
|
||||
# an identity.
|
||||
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
|
||||
if f"head.head.{lora_down_key}.weight" in state_dict:
|
||||
logger.info(
|
||||
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
|
||||
)
|
||||
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
|
||||
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
|
||||
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
|
||||
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
|
||||
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
|
||||
).T
|
||||
|
||||
for text_time in ["text_embedding", "time_embedding"]:
|
||||
if any(text_time in k for k in original_state_dict):
|
||||
for b_n in [0, 2]:
|
||||
@@ -2193,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
state_dict = {convert_key(k): v for k, v in state_dict.items()}
|
||||
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
|
||||
@@ -4940,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
||||
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
|
||||
@@ -36,6 +36,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
@@ -82,6 +84,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
@@ -91,6 +94,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
@@ -98,6 +102,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -133,6 +138,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -169,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .transformers import (
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
@@ -182,6 +190,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
@@ -196,6 +205,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PRXTransformer2DModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3Transformer2DModel,
|
||||
SkyReelsV2Transformer3DModel,
|
||||
StableAudioDiTModel,
|
||||
|
||||
@@ -27,6 +27,8 @@ if torch.distributed.is_available():
|
||||
|
||||
from ..utils import (
|
||||
get_logger,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
@@ -47,6 +49,7 @@ if TYPE_CHECKING:
|
||||
from ._modeling_parallel import ParallelConfig
|
||||
|
||||
_REQUIRED_FLASH_VERSION = "2.6.3"
|
||||
_REQUIRED_AITER_VERSION = "0.1.5"
|
||||
_REQUIRED_SAGE_VERSION = "2.1.1"
|
||||
_REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
@@ -54,6 +57,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
|
||||
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
|
||||
_CAN_USE_NPU_ATTN = is_torch_npu_available()
|
||||
@@ -78,6 +82,12 @@ else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
if not is_kernels_available():
|
||||
raise ImportError(
|
||||
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
|
||||
|
||||
# `aiter`
|
||||
AITER = "aiter"
|
||||
|
||||
# PyTorch native
|
||||
FLEX = "flex"
|
||||
NATIVE = "native"
|
||||
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.SAGE,
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
@@ -630,6 +649,86 @@ def _(
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
def _native_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
# Native attention does not return_lse
|
||||
if return_lse:
|
||||
raise ValueError("Native attention does not support return_lse=True")
|
||||
|
||||
# used for backward pass
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.attn_mask = attn_mask
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.is_causal = is_causal
|
||||
ctx.scale = scale
|
||||
ctx.enable_gqa = enable_gqa
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _native_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
|
||||
query.requires_grad_(True)
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
grad_value = grad_value_t.permute(0, 2, 1, 3)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
|
||||
# forward declaration:
|
||||
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
@@ -1397,6 +1496,47 @@ def _flash_varlen_attention_3(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.AITER,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _aiter_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if not return_lse and torch.is_grad_enabled():
|
||||
# aiter requires return_lse=True by assertion when gradients are enabled.
|
||||
out, lse, *_ = aiter_flash_attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_lse=True,
|
||||
)
|
||||
else:
|
||||
out = aiter_flash_attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLEX,
|
||||
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
|
||||
@@ -1463,6 +1603,7 @@ def _native_flex_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -1478,18 +1619,35 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op=_native_attention_forward_op,
|
||||
backward_op=_native_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -147,14 +147,13 @@ class AutoModel(ConfigMixin):
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -205,7 +204,6 @@ class AutoModel(ConfigMixin):
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
@@ -5,6 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
|
||||
@@ -0,0 +1,709 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanImageResnetBlock(nn.Module):
|
||||
r"""
|
||||
Residual block with two convolutions and optional channel change.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# layers
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
def forward(self, x):
|
||||
# Apply shortcut connection
|
||||
residual = x
|
||||
|
||||
# First normalization and activation
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
x = self.conv_shortcut(x)
|
||||
# Add residual connection
|
||||
return x + residual
|
||||
|
||||
|
||||
class HunyuanImageAttentionBlock(nn.Module):
|
||||
r"""
|
||||
Self-attention with a single head.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
|
||||
# layers
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.to_q = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.to_k = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.to_v = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.proj = nn.Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.norm(x)
|
||||
|
||||
# compute query, key, value
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
batch_size, channels, height, width = query.shape
|
||||
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
|
||||
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
|
||||
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(query, key, value)
|
||||
|
||||
x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
|
||||
# output projection
|
||||
x = self.proj(x)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
class HunyuanImageDownsample(nn.Module):
|
||||
"""
|
||||
Downsampling block for spatial reduction.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
factor = 4
|
||||
if out_channels % factor != 0:
|
||||
raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv(x)
|
||||
|
||||
B, C, H, W = h.shape
|
||||
h = h.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||
h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
|
||||
h = h.reshape(B, 4 * C, H // 2, W // 2)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||
shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
|
||||
shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
|
||||
|
||||
B, C, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanImageUpsample(nn.Module):
|
||||
"""
|
||||
Upsampling block for spatial expansion.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
factor = 4
|
||||
self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv(x)
|
||||
|
||||
B, C, H, W = h.shape
|
||||
h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
|
||||
h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
|
||||
h = h.reshape(B, C // 4, H * 2, W * 2)
|
||||
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
|
||||
B, C, H, W = shortcut.shape
|
||||
shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
|
||||
shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
|
||||
shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanImageMidBlock(nn.Module):
|
||||
"""
|
||||
Middle block for HunyuanImageVAE encoder and decoder.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
num_layers (int): Number of layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, num_layers: int = 1):
|
||||
super().__init__()
|
||||
|
||||
resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
|
||||
|
||||
attentions = []
|
||||
for _ in range(num_layers):
|
||||
attentions.append(HunyuanImageAttentionBlock(in_channels))
|
||||
resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.resnets[0](x)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
x = attn(x)
|
||||
x = resnet(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class HunyuanImageEncoder2D(nn.Module):
|
||||
r"""
|
||||
Encoder network that compresses input to latent representation.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
z_channels (int): Number of latent channels.
|
||||
block_out_channels (list of int): Output channels for each block.
|
||||
num_res_blocks (int): Number of residual blocks per block.
|
||||
spatial_compression_ratio (int): Spatial downsampling factor.
|
||||
non_linearity (str): Type of non-linearity to use. Default is "silu".
|
||||
downsample_match_channel (bool): Whether to match channels during downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
spatial_compression_ratio: int,
|
||||
non_linearity: str = "silu",
|
||||
downsample_match_channel: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if block_out_channels[-1] % (2 * z_channels) != 0:
|
||||
raise ValueError(
|
||||
f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
|
||||
)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
|
||||
self.group_size = block_out_channels[-1] // (2 * z_channels)
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# init block
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
block_in_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
block_out_channel = block_out_channels[i]
|
||||
# residual blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(
|
||||
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
|
||||
)
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
# downsample block
|
||||
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
|
||||
if downsample_match_channel:
|
||||
block_out_channel = block_out_channels[i + 1]
|
||||
self.down_blocks.append(
|
||||
HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
|
||||
)
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
|
||||
|
||||
# output blocks
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## downsamples
|
||||
for down_block in self.down_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
x = self._gradient_checkpointing_func(down_block, x)
|
||||
else:
|
||||
x = down_block(x)
|
||||
|
||||
## middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
x = self._gradient_checkpointing_func(self.mid_block, x)
|
||||
else:
|
||||
x = self.mid_block(x)
|
||||
|
||||
## head
|
||||
B, C, H, W = x.shape
|
||||
residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
|
||||
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv_out(x)
|
||||
return x + residual
|
||||
|
||||
|
||||
class HunyuanImageDecoder2D(nn.Module):
|
||||
r"""
|
||||
Decoder network that reconstructs output from latent representation.
|
||||
|
||||
Args:
|
||||
z_channels : int
|
||||
Number of latent channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
block_out_channels : Tuple[int, ...]
|
||||
Output channels for each block.
|
||||
num_res_blocks : int
|
||||
Number of residual blocks per block.
|
||||
spatial_compression_ratio : int
|
||||
Spatial upsampling factor.
|
||||
upsample_match_channel : bool
|
||||
Whether to match channels during upsampling.
|
||||
non_linearity (str): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
spatial_compression_ratio: int,
|
||||
upsample_match_channel: bool = True,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
if block_out_channels[0] % z_channels != 0:
|
||||
raise ValueError(
|
||||
f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
|
||||
)
|
||||
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.repeat = block_out_channels[0] // z_channels
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# Middle blocks with attention
|
||||
self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
|
||||
|
||||
# Upsampling blocks
|
||||
block_in_channel = block_out_channels[0]
|
||||
self.up_blocks = nn.ModuleList()
|
||||
for i in range(len(block_out_channels)):
|
||||
block_out_channel = block_out_channels[i]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
self.up_blocks.append(
|
||||
HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
|
||||
)
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
|
||||
if upsample_match_channel:
|
||||
block_out_channel = block_out_channels[i + 1]
|
||||
self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
|
||||
block_in_channel = block_out_channel
|
||||
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid_block, h)
|
||||
else:
|
||||
h = self.mid_block(h)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(up_block, h)
|
||||
else:
|
||||
h = up_block(h)
|
||||
h = self.norm_out(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model for 2D images with spatial tiling support.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
# fmt: off
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
latent_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: int,
|
||||
spatial_compression_ratio: int,
|
||||
sample_size: int,
|
||||
scaling_factor: float = None,
|
||||
downsample_match_channel: bool = True,
|
||||
upsample_match_channel: bool = True,
|
||||
) -> None:
|
||||
# fmt: on
|
||||
super().__init__()
|
||||
|
||||
self.encoder = HunyuanImageEncoder2D(
|
||||
in_channels=in_channels,
|
||||
z_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
num_res_blocks=layers_per_block,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
downsample_match_channel=downsample_match_channel,
|
||||
)
|
||||
|
||||
self.decoder = HunyuanImageDecoder2D(
|
||||
z_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
num_res_blocks=layers_per_block,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
upsample_match_channel=upsample_match_channel,
|
||||
)
|
||||
|
||||
# Tiling and slicing configuration
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Tiling parameters
|
||||
self.tile_sample_min_size = sample_size
|
||||
self.tile_latent_min_size = sample_size // spatial_compression_ratio
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_size: Optional[int] = None,
|
||||
tile_overlap_factor: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
|
||||
to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
|
||||
allow processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_size (`int`, *optional*):
|
||||
The minimum size required for a sample to be separated into tiles across the spatial dimension.
|
||||
tile_overlap_factor (`float`, *optional*):
|
||||
The overlap factor required for a latent to be separated into tiles across the spatial dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor):
|
||||
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
||||
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode input using spatial tiling strategy.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded images.
|
||||
"""
|
||||
_, _, _, height, width = x.shape
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_size):
|
||||
row = []
|
||||
for j in range(0, width, overlap_size):
|
||||
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode latent using spatial tiling strategy.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
_, _, height, width = z.shape
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_size):
|
||||
row = []
|
||||
for j in range(0, width, overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
|
||||
return dec
|
||||
@@ -0,0 +1,934 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanImageRefinerCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
bias: bool = True,
|
||||
pad_mode: str = "replicate",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[2] - 1,
|
||||
0,
|
||||
)
|
||||
|
||||
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class HunyuanImageRefinerRMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanImageRefinerAttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
|
||||
|
||||
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
batch_size, channels, frames, height, width = query.shape
|
||||
|
||||
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
|
||||
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
|
||||
|
||||
# batch_size, 1, frames * height * width, channels
|
||||
|
||||
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
class HunyuanImageRefinerUpsampleDCAE(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
|
||||
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_upsample = add_temporal_upsample
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
|
||||
|
||||
Args:
|
||||
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
|
||||
r1: temporal upsampling factor
|
||||
r2: height upsampling factor
|
||||
r3: width upsampling factor
|
||||
"""
|
||||
b, packed_c, f, h, w = tensor.shape
|
||||
factor = r1 * r2 * r3
|
||||
c = packed_c // factor
|
||||
|
||||
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
|
||||
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_upsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_upsample:
|
||||
h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
|
||||
h = h[:, : h.shape[1] // 2]
|
||||
|
||||
# shortcut computation
|
||||
shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
|
||||
shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
|
||||
|
||||
else:
|
||||
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanImageRefinerDownsampleDCAE(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
||||
assert out_channels % factor == 0
|
||||
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_downsample = add_temporal_downsample
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
|
||||
|
||||
This packs spatial/temporal dimensions into channels (opposite of upsample)
|
||||
"""
|
||||
b, c, packed_f, packed_h, packed_w = tensor.shape
|
||||
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
|
||||
|
||||
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
|
||||
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_downsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_downsample:
|
||||
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
|
||||
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
|
||||
h = torch.cat([h, h], dim=1)
|
||||
# shortcut computation
|
||||
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
|
||||
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
|
||||
B, C, T, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
|
||||
else:
|
||||
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
||||
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
||||
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
|
||||
B, C, T, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
||||
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanImageRefinerResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
non_linearity: str = "swish",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
|
||||
self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
|
||||
self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class HunyuanImageRefinerMidBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
add_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.add_attention = add_attention
|
||||
|
||||
# There is always at least one resnet
|
||||
resnets = [
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerDownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
downsample_out_channels: Optional[int] = None,
|
||||
add_temporal_downsample: int = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample_out_channels is not None:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageRefinerDownsampleDCAE(
|
||||
out_channels,
|
||||
out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerUpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
upsample_out_channels: Optional[int] = None,
|
||||
add_temporal_upsample: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
HunyuanImageRefinerResnetBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample_out_channels is not None:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageRefinerUpsampleDCAE(
|
||||
out_channels,
|
||||
out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
|
||||
else:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerEncoder3D(nn.Module):
|
||||
r"""
|
||||
3D vae encoder for HunyuanImageRefiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 64,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 16,
|
||||
downsample_match_channel: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.group_size = block_out_channels[-1] // self.out_channels
|
||||
|
||||
self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
|
||||
output_channel = block_out_channels[i]
|
||||
if not add_spatial_downsample:
|
||||
down_block = HunyuanImageRefinerDownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=None,
|
||||
add_temporal_downsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
else:
|
||||
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
|
||||
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
|
||||
down_block = HunyuanImageRefinerDownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
input_channel = downsample_out_channels
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
|
||||
|
||||
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
|
||||
batch_size, _, frame, height, width = hidden_states.shape
|
||||
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states += short_cut
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRefinerDecoder3D(nn.Module):
|
||||
r"""
|
||||
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 32,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
upsample_match_channel: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.repeat = block_out_channels[0] // self.in_channels
|
||||
|
||||
self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
|
||||
|
||||
# up
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
|
||||
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
|
||||
if add_spatial_upsample or add_temporal_upsample:
|
||||
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
|
||||
up_block = HunyuanImageRefinerUpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
input_channel = upsample_out_channels
|
||||
else:
|
||||
up_block = HunyuanImageRefinerUpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=None,
|
||||
add_temporal_upsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# out
|
||||
self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanImage-2.1 Refiner.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
downsample_match_channel: bool = True,
|
||||
upsample_match_channel: bool = True,
|
||||
scaling_factor: float = 1.03682,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = HunyuanImageRefinerEncoder3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels * 2,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
downsample_match_channel=downsample_match_channel,
|
||||
)
|
||||
|
||||
self.decoder = HunyuanImageRefinerDecoder3D(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
upsample_match_channel=upsample_match_channel,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
||||
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
||||
# intermediate tiles together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 256
|
||||
self.tile_sample_min_width = 256
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 192
|
||||
self.tile_sample_stride_width = 192
|
||||
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_sample_stride_height: Optional[float] = None,
|
||||
tile_sample_stride_width: Optional[float] = None,
|
||||
tile_overlap_factor: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_sample_stride_height (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension.
|
||||
tile_sample_stride_width (`int`, *optional*):
|
||||
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
||||
artifacts produced across the width dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
return dec
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of videos.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
||||
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
||||
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
|
||||
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
|
||||
row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
|
||||
row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = x[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
|
||||
_, _, _, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
||||
overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
||||
blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
|
||||
blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
|
||||
row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
|
||||
row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + tile_latent_min_height,
|
||||
j : j + tile_latent_min_width,
|
||||
]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return dec
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
@@ -1337,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
|
||||
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
|
||||
tile_sample_stride_height = self.tile_sample_stride_height
|
||||
tile_sample_stride_width = self.tile_sample_stride_width
|
||||
if self.config.patch_size is not None:
|
||||
sample_height = sample_height // self.config.patch_size
|
||||
sample_width = sample_width // self.config.patch_size
|
||||
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
|
||||
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
|
||||
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
|
||||
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
|
||||
else:
|
||||
blend_height = self.tile_sample_min_height - tile_sample_stride_height
|
||||
blend_width = self.tile_sample_min_width - tile_sample_stride_width
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
@@ -1353,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
self._conv_idx = [0]
|
||||
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
decoded = self.decoder(
|
||||
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
|
||||
)
|
||||
time.append(decoded)
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
@@ -1369,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
|
||||
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
|
||||
|
||||
if self.config.patch_size is not None:
|
||||
dec = unpatchify(dec, patch_size=self.config.patch_size)
|
||||
|
||||
dec = torch.clamp(dec, min=-1.0, max=1.0)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@@ -286,11 +286,9 @@ class Decoder(nn.Module):
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
|
||||
@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
|
||||
"""
|
||||
This function generates 1D positional embeddings from a grid.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): The embedding dimension `D`
|
||||
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
|
||||
output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
|
||||
dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
|
||||
`torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
|
||||
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
|
||||
# Auto-detect appropriate dtype if not specified
|
||||
if dtype is None:
|
||||
dtype = torch.float32 if pos.device.type == "mps" else torch.float64
|
||||
|
||||
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ if is_torch_available():
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_bria import BriaTransformer2DModel
|
||||
from .transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from .transformer_chroma import ChromaTransformer2DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
@@ -27,6 +28,7 @@ if is_torch_available():
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
@@ -34,6 +36,7 @@ if is_torch_available():
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_prx import PRXTransformer2DModel
|
||||
from .transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from .transformer_sana_video import SanaVideoTransformer3DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
@@ -0,0 +1,655 @@
|
||||
# Copyright (c) Bria.ai. All rights reserved.
|
||||
#
|
||||
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
|
||||
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
|
||||
#
|
||||
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
|
||||
# indicate if changes were made, and do not use the material for commercial purposes.
|
||||
#
|
||||
# See the license for further details.
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_bria import BriaAttnProcessor
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = None
|
||||
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = (None,)
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
if attn.fused_projections:
|
||||
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
|
||||
class BriaFiboAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "BriaFiboAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
|
||||
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = BriaFiboAttnProcessor
|
||||
_available_processors = [BriaFiboAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
context_pre_only: Optional[bool] = None,
|
||||
pre_only: bool = False,
|
||||
elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class BriaFiboEmbedND(torch.nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class BriaFiboSingleTransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.norm = AdaLayerNormZeroSingle(dim)
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
processor = BriaAttnProcessor()
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm="rms_norm",
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BriaFiboTextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
|
||||
class BriaFiboTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
|
||||
self.attn = BriaFiboAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=BriaFiboAttnProcessor(),
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Attention.
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class BriaFiboTimesteps(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
self.time_theta = time_theta
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
max_period=self.time_theta,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class BriaFiboTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, time_theta):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = BriaFiboTimesteps(
|
||||
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
Parameters:
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
||||
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
||||
...
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = None,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
rope_theta=10000,
|
||||
time_theta=10000,
|
||||
text_encoder_dim: int = 2048,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BriaFiboTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BriaFiboSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
caption_projection = [
|
||||
BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
|
||||
for i in range(self.config.num_layers + self.config.num_single_layers)
|
||||
]
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
text_encoder_layers: list = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
if guidance:
|
||||
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if len(txt_ids.shape) == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if len(img_ids.shape) == 3:
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
new_text_encoder_layers = []
|
||||
for i, text_encoder_layer in enumerate(text_encoder_layers):
|
||||
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
|
||||
new_text_encoder_layers.append(text_encoder_layer)
|
||||
text_encoder_layers = new_text_encoder_layers
|
||||
|
||||
block_id = 0
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
current_text_encoder_layer = text_encoder_layers[block_id]
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
|
||||
)
|
||||
block_id += 1
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
current_text_encoder_layer = text_encoder_layers[block_id]
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
|
||||
)
|
||||
block_id += 1
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
|
||||
"""
|
||||
The Transformer model introduced in Flux, modified for Chroma.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma
|
||||
Reference: https://huggingface.co/lodestones/Chroma1-HD
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `1`):
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -717,7 +717,11 @@ class FluxTransformer2DModel(
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
@@ -42,6 +43,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideoAttnProcessor2_0:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
@@ -64,9 +68,9 @@ class HunyuanVideoAttnProcessor2_0:
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
@@ -81,21 +85,29 @@ class HunyuanVideoAttnProcessor2_0:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
query = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
||||
query[:, :, -encoder_hidden_states.shape[1] :],
|
||||
apply_rotary_emb(
|
||||
query[:, : -encoder_hidden_states.shape[1]],
|
||||
image_rotary_emb,
|
||||
sequence_dim=1,
|
||||
),
|
||||
query[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=2,
|
||||
dim=1,
|
||||
)
|
||||
key = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
||||
key[:, :, -encoder_hidden_states.shape[1] :],
|
||||
apply_rotary_emb(
|
||||
key[:, : -encoder_hidden_states.shape[1]],
|
||||
image_rotary_emb,
|
||||
sequence_dim=1,
|
||||
),
|
||||
key[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=2,
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# 4. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
@@ -103,24 +115,31 @@ class HunyuanVideoAttnProcessor2_0:
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=2)
|
||||
key = torch.cat([key, encoder_key], dim=2)
|
||||
value = torch.cat([value, encoder_value], dim=2)
|
||||
query = torch.cat([query, encoder_query], dim=1)
|
||||
key = torch.cat([key, encoder_key], dim=1)
|
||||
value = torch.cat([value, encoder_value], dim=1)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
|
||||
@@ -0,0 +1,971 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanImageAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
query = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(
|
||||
query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
|
||||
),
|
||||
query[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
key = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
|
||||
key[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# 4. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=1)
|
||||
key = torch.cat([key, encoder_key], dim=1)
|
||||
value = torch.cat([value, encoder_value], dim=1)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanImagePatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
if len(patch_size) == 2:
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
elif len(patch_size) == 3:
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
else:
|
||||
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageByT5TextProjection(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_size: int, out_features: int):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(in_features)
|
||||
self.linear_1 = nn.Linear(in_features, hidden_size)
|
||||
self.linear_2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear_3 = nn.Linear(hidden_size, out_features)
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm(encoder_hidden_states)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_3(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageAdaNorm(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or 2 * in_features
|
||||
self.linear = nn.Linear(in_features, out_features)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
def forward(
|
||||
self, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
temb = self.linear(self.nonlinearity(temb))
|
||||
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
||||
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
guidance_embeds: bool = False,
|
||||
use_meanflow: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_meanflow = use_meanflow
|
||||
|
||||
self.time_proj_r = None
|
||||
self.timestep_embedder_r = None
|
||||
if use_meanflow:
|
||||
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_embedder = None
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
timestep_r: Optional[torch.Tensor] = None,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
|
||||
|
||||
if timestep_r is not None:
|
||||
timesteps_proj_r = self.time_proj_r(timestep_r)
|
||||
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
|
||||
timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
|
||||
|
||||
if self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
# IndividualTokenRefinerBlock
|
||||
@maybe_allow_in_graph
|
||||
class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int, # 28
|
||||
attention_head_dim: int, # 128
|
||||
mlp_width_ratio: str = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
|
||||
|
||||
self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
gate_msa, gate_mlp = self.norm_out(temb)
|
||||
hidden_states = hidden_states + attn_output * gate_msa
|
||||
|
||||
ff_output = self.ff(self.norm2(hidden_states))
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageIndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.refiner_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageIndividualTokenRefinerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self_attn_mask = None
|
||||
if attention_mask is not None:
|
||||
batch_size = attention_mask.shape[0]
|
||||
seq_len = attention_mask.shape[1]
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
self_attn_mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.refiner_blocks:
|
||||
hidden_states = block(hidden_states, temb, self_attn_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# txt_in
|
||||
class HunyuanImageTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=hidden_size, pooled_projection_dim=in_channels
|
||||
)
|
||||
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.token_refiner = HunyuanImageIndividualTokenRefiner(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_layers=num_layers,
|
||||
mlp_width_ratio=mlp_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is None:
|
||||
pooled_hidden_states = hidden_states.mean(dim=1)
|
||||
else:
|
||||
original_dtype = hidden_states.dtype
|
||||
mask_float = attention_mask.float().unsqueeze(-1)
|
||||
pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
pooled_hidden_states = pooled_hidden_states.to(original_dtype)
|
||||
|
||||
temb = self.time_text_embed(timestep, pooled_hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanImageRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
|
||||
raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
|
||||
|
||||
if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
|
||||
raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
|
||||
|
||||
if not len(patch_size) == len(rope_dim):
|
||||
raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.rope_dim = rope_dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if hidden_states.ndim == 5:
|
||||
_, _, frame, height, width = hidden_states.shape
|
||||
patch_size_frame, patch_size_height, patch_size_width = self.patch_size
|
||||
rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
|
||||
elif hidden_states.ndim == 4:
|
||||
_, _, height, width = hidden_states.shape
|
||||
patch_size_height, patch_size_width = self.patch_size
|
||||
rope_sizes = [height // patch_size_height, width // patch_size_width]
|
||||
else:
|
||||
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
|
||||
|
||||
axes_grids = []
|
||||
for i in range(len(rope_sizes)):
|
||||
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
||||
axes_grids.append(grid)
|
||||
grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
|
||||
grid = torch.stack(grid, dim=0) # [2, H, W]
|
||||
|
||||
freqs = []
|
||||
for i in range(len(rope_sizes)):
|
||||
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
||||
freqs.append(freq)
|
||||
|
||||
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HunyuanImageSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
mlp_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
bias=True,
|
||||
processor=HunyuanImageAttnProcessor(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
||||
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states = (
|
||||
norm_hidden_states[:, :-text_seq_length, :],
|
||||
norm_hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
|
||||
# 2. Attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, :-text_seq_length, :],
|
||||
hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HunyuanImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanImageAttnProcessor(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of dual-stream blocks to use.
|
||||
num_single_layers (`int`, defaults to `40`):
|
||||
The number of layers of single-stream blocks to use.
|
||||
num_refiner_layers (`int`, defaults to `2`):
|
||||
The number of layers of refiner blocks to use.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
The ratio of the hidden layer size to the input size in the feedforward network.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the spatial patches to use in the patch embedding layer.
|
||||
patch_size_t (`int`, defaults to `1`):
|
||||
The size of the tmeporal patches to use in the patch embedding layer.
|
||||
qk_norm (`str`, defaults to `rms_norm`):
|
||||
The normalization to use for the query and key projections in the attention layers.
|
||||
guidance_embeds (`bool`, defaults to `True`):
|
||||
Whether to use guidance embeddings in the model.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The dimension of the pooled projection of the text embeddings.
|
||||
rope_theta (`float`, defaults to `256.0`):
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
image_condition_type (`str`, *optional*, defaults to `None`):
|
||||
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
|
||||
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
|
||||
tokens in the latent stream and apply conditioning.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
||||
_no_split_modules = [
|
||||
"HunyuanImageTransformerBlock",
|
||||
"HunyuanImageSingleTransformerBlock",
|
||||
"HunyuanImagePatchEmbed",
|
||||
"HunyuanImageTokenRefiner",
|
||||
]
|
||||
_repeated_blocks = [
|
||||
"HunyuanImageTransformerBlock",
|
||||
"HunyuanImageSingleTransformerBlock",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
out_channels: int = 64,
|
||||
num_attention_heads: int = 28,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 20,
|
||||
num_single_layers: int = 40,
|
||||
num_refiner_layers: int = 2,
|
||||
mlp_ratio: float = 4.0,
|
||||
patch_size: Tuple[int, int] = (1, 1),
|
||||
qk_norm: str = "rms_norm",
|
||||
guidance_embeds: bool = False,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_2_dim: Optional[int] = None,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (64, 64),
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
|
||||
raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Latent and condition embedders
|
||||
self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
|
||||
self.context_embedder = HunyuanImageTokenRefiner(
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
if text_embed_2_dim is not None:
|
||||
self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
|
||||
else:
|
||||
self.context_embedder_2 = None
|
||||
|
||||
self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanImageSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
timestep_r: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask_2: Optional[torch.Tensor] = None,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
if hidden_states.ndim == 4:
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
sizes = (height, width)
|
||||
elif hidden_states.ndim == 5:
|
||||
batch_size, channels, frame, height, width = hidden_states.shape
|
||||
sizes = (frame, height, width)
|
||||
else:
|
||||
raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
|
||||
|
||||
post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
encoder_attention_mask = encoder_attention_mask.bool()
|
||||
temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
|
||||
if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
|
||||
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
|
||||
|
||||
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
|
||||
|
||||
# reorder and combine text tokens: combine valid tokens first, then padding
|
||||
new_encoder_hidden_states = []
|
||||
new_encoder_attention_mask = []
|
||||
|
||||
for text, text_mask, text_2, text_mask_2 in zip(
|
||||
encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
|
||||
):
|
||||
# Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
|
||||
new_encoder_hidden_states.append(
|
||||
torch.cat(
|
||||
[
|
||||
text_2[text_mask_2], # valid byt5
|
||||
text[text_mask], # valid mllm
|
||||
text_2[~text_mask_2], # invalid byt5
|
||||
text[~text_mask], # invalid mllm
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply same reordering to attention masks
|
||||
new_encoder_attention_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
text_mask_2[text_mask_2],
|
||||
text_mask[text_mask],
|
||||
text_mask_2[~text_mask_2],
|
||||
text_mask[~text_mask],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
|
||||
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
|
||||
|
||||
attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
# 3. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# 4. Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. unpatchify
|
||||
# reshape: [batch_size, *post_patch_dims, channels, *patch_size]
|
||||
out_channels = self.config.out_channels
|
||||
reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
|
||||
hidden_states = hidden_states.reshape(*reshape_dims)
|
||||
|
||||
# create permutation pattern: batch, channels, then interleave post_patch and patch dims
|
||||
# For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
|
||||
# For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
|
||||
ndim = len(post_patch_sizes)
|
||||
permute_pattern = [0, ndim + 1] # batch, channels
|
||||
for i in range(ndim):
|
||||
permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
|
||||
hidden_states = hidden_states.permute(*permute_pattern)
|
||||
|
||||
# flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
|
||||
# batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
|
||||
final_dims = [batch_size, out_channels] + [
|
||||
post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
|
||||
]
|
||||
hidden_states = hidden_states.reshape(*final_dims)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
@@ -324,6 +324,7 @@ class Kandinsky5AttnProcessor:
|
||||
sparse_params["sta_mask"],
|
||||
thr=sparse_params["P"],
|
||||
)
|
||||
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
@@ -335,6 +336,7 @@ class Kandinsky5AttnProcessor:
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(-2, -1)
|
||||
|
||||
attn_out = attn.out_layer(hidden_states)
|
||||
|
||||
@@ -0,0 +1,703 @@
|
||||
# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GLUMBTempConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: Optional[str] = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
self.conv_temp = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
batch_size, num_frames, height, width, num_channels = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
# Temporal aggregation
|
||||
hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaLinearAttnProcessor3_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product linear attention.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
# B,N,H,C
|
||||
|
||||
query = F.relu(query)
|
||||
key = F.relu(key)
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x1 * cos - x2 * sin
|
||||
out[..., 1::2] = x1 * sin + x2 * cos
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
query_rotate = apply_rotary_emb(query, *rotary_emb)
|
||||
key_rotate = apply_rotary_emb(key, *rotary_emb)
|
||||
|
||||
# B,H,C,N
|
||||
query = query.permute(0, 2, 3, 1)
|
||||
key = key.permute(0, 2, 3, 1)
|
||||
query_rotate = query_rotate.permute(0, 2, 3, 1)
|
||||
key_rotate = key_rotate.permute(0, 2, 3, 1)
|
||||
value = value.permute(0, 2, 3, 1)
|
||||
|
||||
query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
|
||||
|
||||
z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15)
|
||||
|
||||
scores = torch.matmul(value, key_rotate.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query_rotate)
|
||||
|
||||
hidden_states = hidden_states * z
|
||||
# B,H,C,N
|
||||
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq_cos, freq_sin = get_1d_rotary_pos_embed(
|
||||
dim,
|
||||
max_seq_len,
|
||||
theta,
|
||||
use_real=True,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
freqs_cos.append(freq_cos)
|
||||
freqs_sin.append(freq_sin)
|
||||
|
||||
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
|
||||
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
|
||||
class SanaModulatedNorm(nn.Module):
|
||||
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.guidance_condition_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
|
||||
return self.linear(self.silu(conditioning)), conditioning
|
||||
|
||||
|
||||
class SanaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block introduced in [Sana-Video](https://huggingface.co/papers/2509.24695).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2240,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
dropout: float = 0.0,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
attention_bias: bool = True,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 3.0,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
kv_heads=num_attention_heads if qk_norm is not None else None,
|
||||
qk_norm=qk_norm,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
processor=SanaLinearAttnProcessor3_0(),
|
||||
)
|
||||
|
||||
# 2. Cross Attention
|
||||
if cross_attention_dim is not None:
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
qk_norm=qk_norm,
|
||||
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=SanaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = GLUMBTempConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
frames: int = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# 1. Modulation
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 2. Self Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + gate_msa * attn_output
|
||||
|
||||
# 3. Cross Attention
|
||||
if self.attn2 is not None:
|
||||
attn_output = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = ff_output.flatten(1, 3)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `20`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `112`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
|
||||
The number of heads to use for cross-attention.
|
||||
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
|
||||
The number of channels in each head for cross-attention.
|
||||
cross_attention_dim (`int`, *optional*, defaults to `2240`):
|
||||
The number of channels in the cross-attention output.
|
||||
caption_channels (`int`, defaults to `2304`):
|
||||
The number of channels in the caption embeddings.
|
||||
mlp_ratio (`float`, defaults to `2.5`):
|
||||
The expansion ratio to use in the GLUMBConv layer.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether to use bias in the attention layer.
|
||||
sample_size (`int`, defaults to `32`):
|
||||
The base size of the input latent.
|
||||
patch_size (`int`, defaults to `1`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
norm_elementwise_affine (`bool`, defaults to `False`):
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for the query and key.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SanaVideoTransformerBlock", "SanaModulatedNorm"]
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
num_layers: int = 20,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
caption_channels: int = 2304,
|
||||
mlp_ratio: float = 2.5,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = False,
|
||||
sample_size: int = 30,
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
if guidance_embeds:
|
||||
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
|
||||
else:
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
SanaVideoTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
num_cross_attention_heads=num_cross_attention_heads,
|
||||
cross_attention_head_dim=cross_attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output blocks
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
if guidance is not None:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
|
||||
|
||||
# 2. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
rotary_emb,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
else:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
rotary_emb,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
# 3. Normalization
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -555,6 +555,9 @@ class WanTransformer3DModel(
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
"": {
|
||||
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
|
||||
@@ -164,7 +164,11 @@ class AutoOffloadStrategy:
|
||||
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
try:
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
|
||||
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
@@ -699,6 +703,8 @@ class ComponentsManager:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
|
||||
|
||||
# TODO: add a warning if mem_get_info isn't available on `device`.
|
||||
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
remove_hook_from_module(component, recurse=True)
|
||||
|
||||
@@ -598,7 +598,7 @@ class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
|
||||
and getattr(block_state, "image_width", None) is not None
|
||||
):
|
||||
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
|
||||
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
|
||||
img_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
None, image_latent_height // 2, image_latent_width // 2, device, dtype
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=True,
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
@@ -141,7 +141,7 @@ class FluxKontextLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=True,
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
|
||||
@@ -95,7 +95,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
@@ -143,10 +143,6 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
def __init__(self, _auto_resize=True):
|
||||
self._auto_resize = _auto_resize
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
@@ -167,7 +163,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("image")]
|
||||
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
@@ -195,7 +191,8 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
img = images[0]
|
||||
image_height, image_width = components.image_processor.get_default_height_width(img)
|
||||
aspect_ratio = image_width / image_height
|
||||
if self._auto_resize:
|
||||
_auto_resize = block_state._auto_resize
|
||||
if _auto_resize:
|
||||
# Kontext is trained on specific resolutions, using one of them is recommended
|
||||
_, image_width, image_height = min(
|
||||
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
||||
|
||||
@@ -112,6 +112,10 @@ class FluxTextInputStep(ModularPipelineBlocks):
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
|
||||
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, -1
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -130,8 +130,14 @@ class PipelineState:
|
||||
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
|
||||
intermediates dict.
|
||||
"""
|
||||
if name in self.values:
|
||||
return self.values[name]
|
||||
# Use object.__getattribute__ to avoid infinite recursion during deepcopy
|
||||
try:
|
||||
values = object.__getattribute__(self, "values")
|
||||
except AttributeError:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
if name in values:
|
||||
return values[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __repr__(self):
|
||||
@@ -299,15 +305,15 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"local_dir",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
|
||||
config = cls.load_config(pretrained_model_name_or_path)
|
||||
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
@@ -325,11 +331,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
|
||||
block_kwargs = {
|
||||
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
|
||||
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
|
||||
}
|
||||
|
||||
return block_cls(**block_kwargs)
|
||||
@@ -2125,8 +2130,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
component_load_kwargs[key] = value["default"]
|
||||
try:
|
||||
components_to_register[name] = spec.load(**component_load_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create component '{name}': {e}")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
# Register all components at once
|
||||
self.register_components(**components_to_register)
|
||||
@@ -2492,6 +2502,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
else:
|
||||
state = deepcopy(state)
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
@@ -238,19 +238,27 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"encoder_hidden_states_mask": (
|
||||
getattr(block_state, "prompt_embeds_mask", None),
|
||||
getattr(block_state, "negative_prompt_embeds_mask", None),
|
||||
),
|
||||
"txt_seq_lens": (
|
||||
getattr(block_state, "txt_seq_lens", None),
|
||||
getattr(block_state, "negative_txt_seq_lens", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
@@ -328,19 +336,27 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"encoder_hidden_states_mask": (
|
||||
getattr(block_state, "prompt_embeds_mask", None),
|
||||
getattr(block_state, "negative_prompt_embeds_mask", None),
|
||||
),
|
||||
"txt_seq_lens": (
|
||||
getattr(block_state, "txt_seq_lens", None),
|
||||
getattr(block_state, "negative_txt_seq_lens", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
|
||||
@@ -201,27 +201,41 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
) -> PipelineState:
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"time_ids": ("add_time_ids", "negative_add_time_ids"),
|
||||
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
|
||||
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"time_ids": (
|
||||
getattr(block_state, "add_time_ids", None),
|
||||
getattr(block_state, "negative_add_time_ids", None),
|
||||
),
|
||||
"text_embeds": (
|
||||
getattr(block_state, "pooled_prompt_embeds", None),
|
||||
getattr(block_state, "negative_pooled_prompt_embeds", None),
|
||||
),
|
||||
"image_embeds": (
|
||||
getattr(block_state, "ip_adapter_embeds", None),
|
||||
getattr(block_state, "negative_ip_adapter_embeds", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.unet)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||
|
||||
# Predict the noise residual
|
||||
@@ -344,11 +358,23 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"time_ids": ("add_time_ids", "negative_add_time_ids"),
|
||||
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
|
||||
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"time_ids": (
|
||||
getattr(block_state, "add_time_ids", None),
|
||||
getattr(block_state, "negative_add_time_ids", None),
|
||||
),
|
||||
"text_embeds": (
|
||||
getattr(block_state, "pooled_prompt_embeds", None),
|
||||
getattr(block_state, "negative_pooled_prompt_embeds", None),
|
||||
),
|
||||
"image_embeds": (
|
||||
getattr(block_state, "ip_adapter_embeds", None),
|
||||
getattr(block_state, "negative_ip_adapter_embeds", None),
|
||||
),
|
||||
}
|
||||
|
||||
# cond_scale for the timestep (controlnet input)
|
||||
@@ -369,12 +395,15 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
# guided denoiser step
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
|
||||
@@ -94,25 +94,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
) -> PipelineState:
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
}
|
||||
transformer_dtype = components.transformer.dtype
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||
|
||||
# Predict the noise residual
|
||||
|
||||
@@ -128,6 +128,7 @@ else:
|
||||
"AnimateDiffVideoToVideoControlNetPipeline",
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -241,6 +242,7 @@ else:
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
]
|
||||
_import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
@@ -306,6 +308,7 @@ else:
|
||||
"SanaSprintPipeline",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaVideoPipeline",
|
||||
]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
@@ -561,6 +564,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
@@ -640,6 +644,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
@@ -731,7 +736,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
)
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
|
||||
from .sana import (
|
||||
SanaControlNetPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaVideoPipeline,
|
||||
)
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_bria_fibo import BriaFiboPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,838 @@
|
||||
# Copyright (c) Bria.ai. All rights reserved.
|
||||
#
|
||||
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
|
||||
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
|
||||
#
|
||||
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
|
||||
# indicate if changes were made, and do not use the material for commercial purposes.
|
||||
#
|
||||
# See the license for further details.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
|
||||
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
|
||||
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Example:
|
||||
```python
|
||||
import torch
|
||||
from diffusers import BriaFiboPipeline
|
||||
from diffusers.modular_pipelines import ModularPipeline
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
|
||||
|
||||
pipe = BriaFiboPipeline.from_pretrained(
|
||||
"briaai/FIBO",
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
with torch.inference_mode():
|
||||
# 1. Create a prompt to generate an initial image
|
||||
output = vlm_pipe(prompt="a beautiful dog")
|
||||
json_prompt_generate = output.values["json_prompt"]
|
||||
|
||||
# Generate the image from the structured json prompt
|
||||
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
|
||||
results_generate.images[0].save("image_generate.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BriaFiboPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer (`BriaFiboTransformer2DModel`):
|
||||
The transformer model for 2D diffusion modeling.
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
|
||||
Scheduler to be used with `transformer` to denoise the encoded latents.
|
||||
vae (`AutoencoderKLWan`):
|
||||
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
|
||||
text_encoder (`SmolLM3ForCausalLM`):
|
||||
Text encoder for processing input prompts.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
Tokenizer used for processing the input text prompts for the text_encoder.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: BriaFiboTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKLWan,
|
||||
text_encoder: SmolLM3ForCausalLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
):
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.default_sample_size = 64
|
||||
|
||||
def get_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 2048,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if not prompt:
|
||||
raise ValueError("`prompt` must be a non-empty string or list of strings.")
|
||||
|
||||
batch_size = len(prompt)
|
||||
bot_token_id = 128000
|
||||
|
||||
text_encoder_device = device if device is not None else torch.device("cpu")
|
||||
if not isinstance(text_encoder_device, torch.device):
|
||||
text_encoder_device = torch.device(text_encoder_device)
|
||||
|
||||
if all(p == "" for p in prompt):
|
||||
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
tokenized = self.tokenizer(
|
||||
prompt,
|
||||
padding="longest",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = tokenized.input_ids.to(text_encoder_device)
|
||||
attention_mask = tokenized.attention_mask.to(text_encoder_device)
|
||||
|
||||
if any(p == "" for p in prompt):
|
||||
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
|
||||
input_ids[empty_rows] = bot_token_id
|
||||
attention_mask[empty_rows] = 1
|
||||
|
||||
encoder_outputs = self.text_encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_outputs.hidden_states
|
||||
|
||||
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
|
||||
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
hidden_states = tuple(
|
||||
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
|
||||
)
|
||||
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
|
||||
|
||||
return prompt_embeds, hidden_states, attention_mask
|
||||
|
||||
@staticmethod
|
||||
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
|
||||
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
|
||||
batch_size, seq_len, dim = prompt_embeds.shape
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
|
||||
if max_tokens < seq_len:
|
||||
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
|
||||
|
||||
if max_tokens > seq_len:
|
||||
pad_length = max_tokens - seq_len
|
||||
padding = torch.zeros(
|
||||
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
|
||||
|
||||
mask_padding = torch.zeros(
|
||||
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
|
||||
|
||||
return prompt_embeds, attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 3000,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
guidance_scale (`float`):
|
||||
Guidance scale for classifier free guidance.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
prompt_attention_mask = None
|
||||
negative_prompt_attention_mask = None
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
|
||||
|
||||
if guidance_scale > 1:
|
||||
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
|
||||
negative_prompt = ""
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
# Pad to longest
|
||||
if prompt_attention_mask is not None:
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
if negative_prompt_attention_mask is not None:
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
|
||||
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
|
||||
)
|
||||
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
|
||||
|
||||
prompt_embeds, prompt_attention_mask = self.pad_embedding(
|
||||
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
|
||||
)
|
||||
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
|
||||
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
|
||||
else:
|
||||
max_tokens = prompt_embeds.shape[1]
|
||||
prompt_embeds, prompt_attention_mask = self.pad_embedding(
|
||||
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
|
||||
)
|
||||
negative_prompt_layers = None
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt_layers,
|
||||
negative_prompt_layers,
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@staticmethod
|
||||
# Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels)
|
||||
latents = latents.permute(0, 3, 1, 2)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.permute(0, 2, 3, 1)
|
||||
latents = latents.reshape(batch_size, height * width, num_channels_latents)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
do_patching=False,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if do_patching:
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
else:
|
||||
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@staticmethod
|
||||
def _prepare_attention_mask(attention_mask):
|
||||
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
|
||||
|
||||
# convert to 0 - keep, -inf ignore
|
||||
attention_matrix = torch.where(
|
||||
attention_matrix == 1, 0.0, -torch.inf
|
||||
) # Apply -inf to ignored tokens for nulling softmax score
|
||||
return attention_matrix
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 3000,
|
||||
do_patching=False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
|
||||
do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
|
||||
Examples:
|
||||
Returns:
|
||||
[`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt_layers,
|
||||
negative_prompt_layers,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if guidance_scale > 1:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_layers = [
|
||||
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
|
||||
]
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
|
||||
self.transformer.single_transformer_blocks
|
||||
)
|
||||
if len(prompt_layers) >= total_num_layers_transformer:
|
||||
# remove first layers
|
||||
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
|
||||
else:
|
||||
# duplicate last layer
|
||||
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
|
||||
|
||||
# 5. Prepare latent variables
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
if do_patching:
|
||||
num_channels_latents = int(num_channels_latents / 4)
|
||||
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
prompt_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
do_patching,
|
||||
)
|
||||
|
||||
latent_attention_mask = torch.ones(
|
||||
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
|
||||
)
|
||||
if guidance_scale > 1:
|
||||
latent_attention_mask = latent_attention_mask.repeat(2, 1)
|
||||
|
||||
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
|
||||
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
|
||||
|
||||
if self._joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
self._joint_attention_kwargs["attention_mask"] = attention_mask
|
||||
|
||||
# Adapt scheduler to dynamic shifting (resolution dependent)
|
||||
|
||||
if do_patching:
|
||||
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
|
||||
else:
|
||||
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
|
||||
mu = calculate_shift(
|
||||
seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
|
||||
# Init sigmas and timesteps according to shift size
|
||||
# This changes the scheduler in-place according to the dynamic scheduling
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
timesteps=None,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# Support old different diffusers versions
|
||||
if len(latent_image_ids.shape) == 3:
|
||||
latent_image_ids = latent_image_ids[0]
|
||||
|
||||
if len(text_ids.shape) == 3:
|
||||
text_ids = text_ids[0]
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(
|
||||
device=latent_model_input.device, dtype=latent_model_input.dtype
|
||||
)
|
||||
|
||||
# This is predicts "v" from flow-matching or eps from diffusion
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
text_encoder_layers=prompt_layers,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if guidance_scale > 1:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
if do_patching:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
else:
|
||||
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
|
||||
|
||||
latents = latents.unsqueeze(dim=2)
|
||||
latents_device = latents[0].device
|
||||
latents_dtype = latents[0].dtype
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents_device, latents_dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents_device, latents_dtype
|
||||
)
|
||||
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
|
||||
latents_scaled = torch.cat(latents_scaled, dim=0)
|
||||
image = []
|
||||
for scaled_latent in latents_scaled:
|
||||
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
|
||||
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
|
||||
image.append(curr_image)
|
||||
if len(image) == 1:
|
||||
image = image[0]
|
||||
else:
|
||||
image = np.stack(image, axis=0)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return BriaFiboPipelineOutput(images=image)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 3000:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class BriaFiboPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for BriaFibo pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaPipeline
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> model_id = "lodestones/Chroma1-HD"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
|
||||
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = ChromaPipeline.from_pretrained(
|
||||
... model_id,
|
||||
@@ -158,7 +158,7 @@ class ChromaPipeline(
|
||||
r"""
|
||||
The Chroma pipeline for text-to-image generation.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma/
|
||||
Reference: https://huggingface.co/lodestones/Chroma1-HD/
|
||||
|
||||
Args:
|
||||
transformer ([`ChromaTransformer2DModel`]):
|
||||
@@ -233,20 +233,23 @@ class ChromaPipeline(
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.clone()
|
||||
tokenizer_mask = text_inputs.attention_mask
|
||||
|
||||
# Chroma requires the attention mask to include one padding token
|
||||
seq_lengths = attention_mask.sum(dim=1)
|
||||
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
|
||||
tokenizer_mask_device = tokenizer_mask.to(device)
|
||||
|
||||
# unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=False,
|
||||
attention_mask=tokenizer_mask_device,
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.to(device=device)
|
||||
|
||||
# for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
|
||||
seq_lengths = tokenizer_mask_device.sum(dim=1)
|
||||
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> model_id = "lodestones/Chroma1-HD"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
|
||||
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||
... model_id,
|
||||
... transformer=transformer,
|
||||
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
|
||||
r"""
|
||||
The Chroma pipeline for image-to-image generation.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma/
|
||||
Reference: https://huggingface.co/lodestones/Chroma1-HD/
|
||||
|
||||
Args:
|
||||
transformer ([`ChromaTransformer2DModel`]):
|
||||
@@ -247,20 +247,21 @@ class ChromaImg2ImgPipeline(
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.clone()
|
||||
tokenizer_mask = text_inputs.attention_mask
|
||||
|
||||
# Chroma requires the attention mask to include one padding token
|
||||
seq_lengths = attention_mask.sum(dim=1)
|
||||
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
|
||||
tokenizer_mask_device = tokenizer_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=False,
|
||||
attention_mask=tokenizer_mask_device,
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.to(dtype=dtype, device=device)
|
||||
|
||||
seq_lengths = tokenizer_mask_device.sum(dim=1)
|
||||
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline(
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline(
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
+2
-2
@@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
|
||||
_import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_hunyuanimage import HunyuanImagePipeline
|
||||
from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,866 @@
|
||||
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
|
||||
|
||||
from ...guiders import AdaptiveProjectedMixGuidance
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import HunyuanImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import HunyuanImagePipeline
|
||||
|
||||
>>> pipe = HunyuanImagePipeline.from_pretrained(
|
||||
... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
|
||||
>>> image.save("hunyuanimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def extract_glyph_text(prompt: str):
|
||||
"""
|
||||
Extract text enclosed in quotes for glyph rendering.
|
||||
|
||||
Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
|
||||
|
||||
Args:
|
||||
prompt: Input text prompt
|
||||
|
||||
Returns:
|
||||
Formatted glyph text string or None if no quoted text found
|
||||
"""
|
||||
text_prompt_texts = []
|
||||
pattern_quote_single = r"\'(.*?)\'"
|
||||
pattern_quote_double = r"\"(.*?)\""
|
||||
pattern_quote_chinese_single = r"‘(.*?)’"
|
||||
pattern_quote_chinese_double = r"“(.*?)”"
|
||||
|
||||
matches_quote_single = re.findall(pattern_quote_single, prompt)
|
||||
matches_quote_double = re.findall(pattern_quote_double, prompt)
|
||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
|
||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
|
||||
|
||||
text_prompt_texts.extend(matches_quote_single)
|
||||
text_prompt_texts.extend(matches_quote_double)
|
||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||
|
||||
if text_prompt_texts:
|
||||
glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
|
||||
else:
|
||||
glyph_text_formatted = None
|
||||
|
||||
return glyph_text_formatted
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HunyuanImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
The HunyuanImage pipeline for text-to-image generation.
|
||||
|
||||
Args:
|
||||
transformer ([`HunyuanImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanImage`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
|
||||
text_encoder_2 ([`T5EncoderModel`]):
|
||||
[T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
|
||||
variant.
|
||||
tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
|
||||
guider ([`AdaptiveProjectedMixGuidance`]):
|
||||
[AdaptiveProjectedMixGuidance]to be used to guide the image generation.
|
||||
ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
|
||||
[AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
_optional_components = ["ocr_guider", "guider"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLHunyuanImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: ByT5Tokenizer,
|
||||
transformer: HunyuanImageTransformer2DModel,
|
||||
guider: Optional[AdaptiveProjectedMixGuidance] = None,
|
||||
ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
guider=guider,
|
||||
ocr_guider=ocr_guider,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = 1000
|
||||
self.tokenizer_2_max_length = 128
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
self.prompt_template_encode_start_idx = 34
|
||||
self.default_sample_size = 64
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tokenizer_max_length: int = 1000,
|
||||
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
|
||||
drop_idx: int = 34,
|
||||
hidden_state_skip_layer: int = 2,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(
|
||||
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
encoder_hidden_states = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
|
||||
prompt_embeds = prompt_embeds[:, drop_idx:]
|
||||
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
encoder_attention_mask = encoder_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def _get_byt5_prompt_embeds(
|
||||
self,
|
||||
tokenizer: ByT5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: str,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tokenizer_max_length: int = 128,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
if isinstance(prompt, list):
|
||||
raise ValueError("byt5 prompt should be a string")
|
||||
elif prompt is None:
|
||||
raise ValueError("byt5 prompt should not be None")
|
||||
|
||||
txt_tokens = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer_max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask.float(),
|
||||
)[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
batch_size: int = 1,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_2: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
batch_size (`int`):
|
||||
batch size of prompts, defaults to 1
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
|
||||
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
|
||||
argument using self.tokenizer_2 and self.text_encoder_2.
|
||||
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
|
||||
argument using self.tokenizer_2 and self.text_encoder_2.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = [""] * batch_size
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
tokenizer_max_length=self.tokenizer_max_length,
|
||||
template=self.prompt_template_encode,
|
||||
drop_idx=self.prompt_template_encode_start_idx,
|
||||
)
|
||||
|
||||
if prompt_embeds_2 is None:
|
||||
prompt_embeds_2_list = []
|
||||
prompt_embeds_mask_2_list = []
|
||||
|
||||
glyph_texts = [extract_glyph_text(p) for p in prompt]
|
||||
for glyph_text in glyph_texts:
|
||||
if glyph_text is None:
|
||||
glyph_text_embeds = torch.zeros(
|
||||
(1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
|
||||
)
|
||||
glyph_text_embeds_mask = torch.zeros(
|
||||
(1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
|
||||
)
|
||||
else:
|
||||
glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
|
||||
tokenizer=self.tokenizer_2,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prompt=glyph_text,
|
||||
device=device,
|
||||
tokenizer_max_length=self.tokenizer_2_max_length,
|
||||
)
|
||||
|
||||
prompt_embeds_2_list.append(glyph_text_embeds)
|
||||
prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
|
||||
|
||||
prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
|
||||
prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
_, seq_len_2, _ = prompt_embeds_2.shape
|
||||
prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
|
||||
prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
prompt_embeds_2=None,
|
||||
prompt_embeds_mask_2=None,
|
||||
negative_prompt_embeds_2=None,
|
||||
negative_prompt_embeds_mask_2=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if prompt is None and prompt_embeds_2 is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
||||
)
|
||||
|
||||
if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
|
||||
)
|
||||
if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
distilled_guidance_scale: Optional[float] = 3.25,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_2: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask_2: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
|
||||
not provided, will use an empty negative prompt. Ignored when not using guidance. ).
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
distilled_guidance_scale (`float`, *optional*, defaults to None):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
|
||||
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
|
||||
ignored.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, text embeddings mask will be generated from `prompt` input argument.
|
||||
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
|
||||
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
|
||||
input argument.
|
||||
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
|
||||
input argument.
|
||||
negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
|
||||
`negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds_2=prompt_embeds_2,
|
||||
prompt_embeds_mask_2=prompt_embeds_mask_2,
|
||||
negative_prompt_embeds_2=negative_prompt_embeds_2,
|
||||
negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare prompt embeds
|
||||
|
||||
prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds_2=prompt_embeds_2,
|
||||
prompt_embeds_mask_2=prompt_embeds_mask_2,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
|
||||
prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
|
||||
|
||||
# select guider
|
||||
if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
|
||||
# prompt contains ocr and pipeline has a guider for ocr
|
||||
guider = self.ocr_guider
|
||||
elif self.guider is not None:
|
||||
guider = self.guider
|
||||
# distilled model does not use guidance method, use default guider with enabled=False
|
||||
else:
|
||||
guider = AdaptiveProjectedMixGuidance(enabled=False)
|
||||
|
||||
if guider._enabled and guider.num_conditions > 1:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask,
|
||||
negative_prompt_embeds_2,
|
||||
negative_prompt_embeds_mask_2,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds_2=negative_prompt_embeds_2,
|
||||
prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
|
||||
negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance (for guidance-distilled model)
|
||||
if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
|
||||
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
|
||||
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = (
|
||||
torch.tensor(
|
||||
[distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
|
||||
)
|
||||
* 1000.0
|
||||
)
|
||||
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if self.transformer.config.use_meanflow:
|
||||
if i == len(timesteps) - 1:
|
||||
timestep_r = torch.tensor([0.0], device=device)
|
||||
else:
|
||||
timestep_r = timesteps[i + 1]
|
||||
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
|
||||
else:
|
||||
timestep_r = None
|
||||
|
||||
# Step 1: Collect model inputs needed for the guidance method
|
||||
# conditional inputs should always be first element in the tuple
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
||||
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
|
||||
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
|
||||
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
|
||||
}
|
||||
|
||||
# Step 2: Update guider's internal state for this denoising step
|
||||
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
|
||||
# Step 3: Prepare batched model inputs based on the guidance method
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = guider.prepare_inputs(guider_inputs)
|
||||
# Step 4: Run the denoiser for each batch
|
||||
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
|
||||
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
|
||||
for guider_state_batch in guider_state:
|
||||
guider.prepare_models(self.transformer)
|
||||
|
||||
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
|
||||
cond_kwargs = {
|
||||
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
|
||||
}
|
||||
|
||||
# e.g. "pred_cond"/"pred_uncond"
|
||||
context_name = getattr(guider_state_batch, guider._identifier_key)
|
||||
with self.transformer.cache_context(context_name):
|
||||
# Run denoiser and store noise prediction in this batch
|
||||
guider_state_batch.noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep,
|
||||
timestep_r=timestep_r,
|
||||
guidance=guidance,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
|
||||
# Cleanup model (e.g., remove hooks)
|
||||
guider.cleanup_models(self.transformer)
|
||||
|
||||
# Step 5: Combine predictions using the guidance method
|
||||
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
|
||||
# Continuing the CFG example, the guider receives:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
|
||||
# ]
|
||||
# And extracts predictions using the __guidance_identifier__:
|
||||
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
|
||||
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
|
||||
# Then applies CFG formula:
|
||||
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
||||
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
noise_pred = guider(guider_state)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return HunyuanImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,752 @@
|
||||
# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from ...guiders import AdaptiveProjectedMixGuidance
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import HunyuanImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import HunyuanImageRefinerPipeline
|
||||
|
||||
>>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
|
||||
... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = load_image("path/to/image.png")
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
|
||||
>>> image.save("hunyuanimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class HunyuanImageRefinerPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
The HunyuanImage pipeline for text-to-image generation.
|
||||
|
||||
Args:
|
||||
transformer ([`HunyuanImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLHunyuanImageRefiner`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
_optional_components = ["guider"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLHunyuanImageRefiner,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
transformer: HunyuanImageTransformer2DModel,
|
||||
guider: Optional[AdaptiveProjectedMixGuidance] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
guider=guider,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = 256
|
||||
self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
self.prompt_template_encode_start_idx = 36
|
||||
self.default_sample_size = 64
|
||||
self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
|
||||
|
||||
# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tokenizer_max_length: int = 1000,
|
||||
template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
|
||||
drop_idx: int = 34,
|
||||
hidden_state_skip_layer: int = 2,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(
|
||||
txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
encoder_hidden_states = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
|
||||
prompt_embeds = prompt_embeds[:, drop_idx:]
|
||||
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
encoder_attention_mask = encoder_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
batch_size: int = 1,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
batch_size (`int`):
|
||||
batch size of prompts, defaults to 1
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
|
||||
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
|
||||
argument using self.tokenizer_2 and self.text_encoder_2.
|
||||
prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
|
||||
Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
|
||||
argument using self.tokenizer_2 and self.text_encoder_2.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = [""] * batch_size
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
tokenizer_max_length=self.tokenizer_max_length,
|
||||
template=self.prompt_template_encode,
|
||||
drop_idx=self.prompt_template_encode_start_idx,
|
||||
)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image_latents,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
strength=0.25,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, 1, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
cond_latents = strength * noise + (1 - strength) * image_latents
|
||||
|
||||
return latents, cond_latents
|
||||
|
||||
@staticmethod
|
||||
def _reorder_image_tokens(image_latents):
|
||||
image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
|
||||
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4)
|
||||
image_latents = image_latents.reshape(
|
||||
batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
|
||||
)
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return image_latents
|
||||
|
||||
@staticmethod
|
||||
def _restore_image_tokens_order(latents):
|
||||
"""Restore image tokens order by splitting channels and removing first frame slice."""
|
||||
batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
|
||||
latents = latents.reshape(
|
||||
batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
|
||||
) # B, F*2, C//2, H, W
|
||||
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
|
||||
# Remove first frame slice
|
||||
latents = latents[:, :, 1:]
|
||||
|
||||
return latents
|
||||
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
|
||||
image_latents = self._reorder_image_tokens(image_latents)
|
||||
|
||||
image_latents = image_latents * self.vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
distilled_guidance_scale: Optional[float] = 3.25,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 4,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, will use an empty negative
|
||||
prompt. Ignored when not using guidance.
|
||||
distilled_guidance_scale (`float`, *optional*, defaults to None):
|
||||
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
||||
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
||||
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
||||
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
|
||||
guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
|
||||
ignored.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. process image
|
||||
if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
|
||||
image_latents = image
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height, width)
|
||||
image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
|
||||
# 3.prepare prompt embeds
|
||||
|
||||
if self.guider is not None:
|
||||
guider = self.guider
|
||||
else:
|
||||
# distilled model does not use guidance method, use default guider with enabled=False
|
||||
guider = AdaptiveProjectedMixGuidance(enabled=False)
|
||||
|
||||
requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
|
||||
|
||||
if requires_unconditional_embeds:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
latents, cond_latents = self.prepare_latents(
|
||||
image_latents=image_latents,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=self.latent_channels,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance (this pipeline only supports guidance-distilled models)
|
||||
if distilled_guidance_scale is None:
|
||||
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
|
||||
guidance = (
|
||||
torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
|
||||
* 1000.0
|
||||
)
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
# Step 1: Collect model inputs needed for the guidance method
|
||||
# conditional inputs should always be first element in the tuple
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
||||
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
|
||||
}
|
||||
|
||||
# Step 2: Update guider's internal state for this denoising step
|
||||
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
|
||||
# Step 3: Prepare batched model inputs based on the guidance method
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# Step 4: Run the denoiser for each batch
|
||||
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
|
||||
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
|
||||
for guider_state_batch in guider_state:
|
||||
guider.prepare_models(self.transformer)
|
||||
|
||||
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
|
||||
cond_kwargs = {
|
||||
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
|
||||
}
|
||||
|
||||
# e.g. "pred_cond"/"pred_uncond"
|
||||
context_name = getattr(guider_state_batch, guider._identifier_key)
|
||||
with self.transformer.cache_context(context_name):
|
||||
# Run denoiser and store noise prediction in this batch
|
||||
guider_state_batch.noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
|
||||
# Cleanup model (e.g., remove hooks)
|
||||
guider.cleanup_models(self.transformer)
|
||||
|
||||
# Step 5: Combine predictions using the guidance method
|
||||
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
|
||||
# Continuing the CFG example, the guider receives:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
|
||||
# ]
|
||||
# And extracts predictions using the __guidance_identifier__:
|
||||
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
|
||||
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
|
||||
# Then applies CFG formula:
|
||||
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
||||
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
noise_pred = guider(guider_state)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
latents = self._restore_image_tokens_order(latents)
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return HunyuanImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for HunyuanImage pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -113,7 +113,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
_cut_context=False,
|
||||
_cut_context=True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
|
||||
@@ -173,8 +173,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
)
|
||||
self.prompt_template_encode_start_idx = 129
|
||||
|
||||
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
|
||||
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
)
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
@staticmethod
|
||||
@@ -384,6 +386,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = [prompt_clean(p) for p in prompt]
|
||||
|
||||
@@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -33,6 +33,7 @@ from ..utils import (
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_maybe_remap_transformers_class,
|
||||
deprecate,
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
@@ -75,6 +76,7 @@ LOADABLE_CLASSES = {
|
||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
||||
"BaseGuidance": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||
@@ -356,6 +358,11 @@ def maybe_raise_or_warn(
|
||||
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
# Handle deprecated Transformers classes
|
||||
if library_name == "transformers":
|
||||
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
@@ -390,6 +397,11 @@ def simple_get_class_obj(library_name, class_name):
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
# Handle deprecated Transformers classes
|
||||
if library_name == "transformers":
|
||||
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
|
||||
return class_obj
|
||||
@@ -416,6 +428,10 @@ def get_class_obj_and_candidates(
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
# Handle deprecated Transformers classes
|
||||
if library_name == "transformers":
|
||||
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
|
||||
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
|
||||
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
|
||||
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_sana_controlnet import SanaControlNetPipeline
|
||||
from .pipeline_sana_sprint import SanaSprintPipeline
|
||||
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
|
||||
from .pipeline_sana_video import SanaVideoPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
@@ -19,3 +20,18 @@ class SanaPipelineOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SanaVideoPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Sana-Video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,7 @@ from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
class StableCascadeDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating images from the Stable Cascade model.
|
||||
|
||||
@@ -79,6 +79,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
width=int(24*10.67)=256 in order to match the training conditions.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.35.2"
|
||||
|
||||
unet_name = "decoder"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
|
||||
|
||||
@@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import is_torch_version, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
@@ -42,7 +42,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
class StableCascadeCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Combined Pipeline for text-to-image generation using Stable Cascade.
|
||||
|
||||
@@ -74,6 +74,8 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.35.2"
|
||||
|
||||
_load_connected_pipes = True
|
||||
_optional_components = ["prior_feature_extractor", "prior_image_encoder"]
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -77,7 +77,7 @@ class StableCascadePriorPipelineOutput(BaseOutput):
|
||||
negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
class StableCascadePriorPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating image prior for Stable Cascade.
|
||||
|
||||
@@ -103,6 +103,8 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
Default resolution for multiple images generated.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.35.2"
|
||||
|
||||
unet_name = "prior"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "image_encoder->text_encoder->prior"
|
||||
|
||||
@@ -248,7 +248,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
max_sequence_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
@@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanVACETransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
||||
`transformer` is used.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
transformer ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
|
||||
`transformer` or `transformer_2` must be provided.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
|
||||
`transformer` or `transformer_2` must be provided.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: WanVACETransformer3DModel = None,
|
||||
transformer_2: WanVACETransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
reference_images=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if self.transformer is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
elif self.transformer_2 is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
|
||||
)
|
||||
|
||||
if height % base != 0 or width % base != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
|
||||
|
||||
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
if video is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
base = self.vae_scale_factor_spatial * (
|
||||
self.transformer.config.patch_size[1]
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size[1]
|
||||
)
|
||||
video_height, video_width = self.video_processor.get_default_height_width(video[0])
|
||||
|
||||
if video_height * video_width > height * width:
|
||||
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
"Generating with more than one video is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
transformer_patch_size = self.transformer.config.patch_size[1]
|
||||
transformer_patch_size = (
|
||||
self.transformer.config.patch_size[1]
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.patch_size[1]
|
||||
)
|
||||
|
||||
mask_list = []
|
||||
for mask_, reference_images_batch in zip(mask, reference_images):
|
||||
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
|
||||
vace_layers = (
|
||||
self.transformer.config.vace_layers
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.vace_layers
|
||||
)
|
||||
if isinstance(conditioning_scale, (int, float)):
|
||||
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
|
||||
conditioning_scale = [conditioning_scale] * len(vace_layers)
|
||||
if isinstance(conditioning_scale, list):
|
||||
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
|
||||
if len(conditioning_scale) != len(vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
|
||||
)
|
||||
conditioning_scale = torch.tensor(conditioning_scale)
|
||||
if isinstance(conditioning_scale, torch.Tensor):
|
||||
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
|
||||
if conditioning_scale.size(0) != len(vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
|
||||
)
|
||||
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
|
||||
|
||||
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
|
||||
conditioning_latents = conditioning_latents.to(transformer_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_channels_latents = (
|
||||
self.transformer.config.in_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.in_channels
|
||||
)
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -38,7 +38,7 @@ from .constants import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
from .deprecation_utils import deprecate
|
||||
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
|
||||
from .doc_utils import replace_example_docstring
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
|
||||
@@ -64,6 +64,8 @@ from .import_utils import (
|
||||
get_objects_from_module,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_better_profanity_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
|
||||
@@ -45,7 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
|
||||
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
|
||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
|
||||
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
|
||||
@@ -4,6 +4,54 @@ from typing import Any, Dict, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Mapping for deprecated Transformers classes to their replacements
|
||||
# This is used to handle models that reference deprecated class names in their configs
|
||||
# Reference: https://github.com/huggingface/transformers/issues/40822
|
||||
# Format: {
|
||||
# "DeprecatedClassName": {
|
||||
# "new_class": "NewClassName",
|
||||
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
|
||||
# }
|
||||
# }
|
||||
_TRANSFORMERS_CLASS_REMAPPING = {
|
||||
"CLIPFeatureExtractor": {
|
||||
"new_class": "CLIPImageProcessor",
|
||||
"transformers_version": (">", "4.57.0"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
|
||||
"""
|
||||
Check if a Transformers class should be remapped to a newer version.
|
||||
|
||||
Args:
|
||||
class_name: The name of the class to check
|
||||
|
||||
Returns:
|
||||
The new class name if remapping should occur, None otherwise
|
||||
"""
|
||||
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
|
||||
return None
|
||||
|
||||
from .import_utils import is_transformers_version
|
||||
|
||||
mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
|
||||
operation, required_version = mapping["transformers_version"]
|
||||
|
||||
# Only remap if the transformers version meets the requirement
|
||||
if is_transformers_version(operation, required_version):
|
||||
new_class = mapping["new_class"]
|
||||
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
|
||||
return mapping["new_class"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
|
||||
from .. import __version__
|
||||
|
||||
@@ -17,6 +17,21 @@ class AdaptiveProjectedGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AdaptiveProjectedMixGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -32,6 +47,21 @@ class AutoGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BaseGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -378,6 +408,36 @@ class AutoencoderKLCosmos(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -528,6 +588,21 @@ class AutoModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BriaFiboTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BriaTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -858,6 +933,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1218,6 +1308,21 @@ class SanaTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SanaVideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SD3ControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user