Compare commits
73 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ce43e58afe | |||
| 51fbe6a1ed | |||
| d456b5d925 | |||
| b9be438a2a | |||
| e3d6945bcf | |||
| 01f488e6d9 | |||
| d8247cfbb3 | |||
| 57cf1b4134 | |||
| 8094f660cf | |||
| 15e3a0fe60 | |||
| 6bf668c4d2 | |||
| c6fc91031e | |||
| ffc95627a9 | |||
| b3b11b5fc2 | |||
| 511c7a4c40 | |||
| 1579a83d9a | |||
| 64eaf85403 | |||
| 6ac18d00e7 | |||
| 5f8a9b61b5 | |||
| abace0508f | |||
| 2d4f144da4 | |||
| 41c59213da | |||
| 755dc49d1b | |||
| 4bd7dd56b8 | |||
| 89ebea4fbd | |||
| b1a883517d | |||
| e6d4612309 | |||
| a88a7b4f03 | |||
| c8656ed73c | |||
| 94c9613f99 | |||
| b91e8c0d0b | |||
| ac7864624b | |||
| 5ffb73d4ae | |||
| 4088e8a851 | |||
| d33d9f6715 | |||
| dde8754ba2 | |||
| fbcd3ba6b2 | |||
| d176f61fcf | |||
| d13e6c08fd | |||
| 6c0d55de20 | |||
| 354d35adb0 | |||
| db38c47807 | |||
| 4839692df2 | |||
| 54adb215a0 | |||
| c8176bfe04 | |||
| 9322997b24 | |||
| ef67154217 | |||
| 7c9dc971ac | |||
| c12a61f216 | |||
| debafc6960 | |||
| 8048623daf | |||
| 72fc6ad797 | |||
| 544ba677dd | |||
| 6f1042e36c | |||
| 35dd13c5a4 | |||
| 8ca0fa8ea4 | |||
| 8832deef93 | |||
| 8885a13c9a | |||
| c5e9a4a648 | |||
| fc87f40e7a | |||
| 9fecdc973b | |||
| a2b7de3611 | |||
| 5b3295ad48 | |||
| 78f292ea77 | |||
| d684d4647f | |||
| b2f0ff7454 | |||
| 435a8c02af | |||
| d7ef6a0104 | |||
| c9a9559600 | |||
| 9df6c2f580 | |||
| 13cf2b0c28 | |||
| 369847397e | |||
| d47d60f3e6 |
@@ -349,6 +349,8 @@
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/flux2_transformer
|
||||
title: Flux2Transformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
@@ -525,6 +527,8 @@
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/flux2
|
||||
title: Flux2
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hidream
|
||||
|
||||
@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
|
||||
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
> [!TIP]
|
||||
@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
|
||||
|
||||
## Flux2LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Flux2Transformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
|
||||
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
@@ -0,0 +1,33 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Flux2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
|
||||
|
||||
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!TIP]
|
||||
> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
|
||||
>
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## Flux2Pipeline
|
||||
|
||||
[[autodoc]] Flux2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -43,11 +43,13 @@ Note: The recommended dtype mentioned is for the transformer weights. The text e
|
||||
<hfoptions id="generation pipelines">`
|
||||
<hfoption id="Text-to-Video">
|
||||
|
||||
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text descriptio and a starting frame.
|
||||
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.
|
||||
|
||||
```python
|
||||
model_id =
|
||||
pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
pipe.vae.to(torch.float32)
|
||||
pipe.to("cuda")
|
||||
@@ -75,12 +77,11 @@ export_to_video(video, "sana_video.mp4", fps=16)
|
||||
</hfoption>
|
||||
<hfoption id="Image-to-Video">
|
||||
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text descriptio and a starting frame.
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
|
||||
|
||||
```python
|
||||
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
|
||||
pipe = SanaImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
|
||||
|
||||
@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: 🧨 Diffusers
|
||||
- local: quicktour
|
||||
title: Tour rápido
|
||||
- local: installation
|
||||
title: Instalação
|
||||
- local: index
|
||||
title: Diffusers
|
||||
- local: installation
|
||||
title: Instalação
|
||||
- local: quicktour
|
||||
title: Tour rápido
|
||||
- local: stable_diffusion
|
||||
title: Desempenho básico
|
||||
title: Primeiros passos
|
||||
|
||||
@@ -18,11 +18,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Diffusers
|
||||
|
||||
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou queira treinar seu próprio modelo de difusão, 🤗 Diffusers é uma modular caixa de ferramentas que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||
|
||||
A Biblioteca tem três componentes principais:
|
||||
|
||||
- Pipelines de última geração para a geração em poucas linhas de código. Têm muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
|
||||
- Pipelines de última geração para a geração em poucas linhas de código. Há muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
|
||||
- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.
|
||||
- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Recomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).
|
||||
Se você não está familiarizado com ambiente virtuals, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
Um ambiente virtual deixa mais fácil gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
|
||||
Um ambiente virtual facilita gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
|
||||
|
||||
Comece criando um ambiente virtual no diretório do projeto:
|
||||
|
||||
@@ -100,12 +100,12 @@ pip install -e ".[flax]"
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
|
||||
Esses comandos irão vincular a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
|
||||
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
|
||||
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
|
||||
|
||||
> [!WARNING]
|
||||
> Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
|
||||
> Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.
|
||||
|
||||
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
|
||||
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Desempenho básico
|
||||
|
||||
Difusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.
|
||||
|
||||
Este guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory) para guias de desempenho mais detalhados.
|
||||
|
||||
## Uso de memória
|
||||
|
||||
Reduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.
|
||||
|
||||
O método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
pipeline(prompt).images[0]
|
||||
print(f"Memória máxima reservada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
|
||||
```
|
||||
|
||||
## Velocidade de inferência
|
||||
|
||||
O processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.
|
||||
|
||||
- Adicione `device_map="cuda"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.
|
||||
- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import time
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.
|
||||
- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.
|
||||
|
||||
```py
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
image = pipeline(prompt).images[0]
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"Geração de imagem levou {end_time - start_time:.3f} segundos")
|
||||
```
|
||||
|
||||
## Qualidade de geração
|
||||
|
||||
Muitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.
|
||||
|
||||
- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
negative_prompt = "low quality, blurry, ugly, poor details"
|
||||
pipeline(prompt, negative_prompt=negative_prompt).images[0]
|
||||
```
|
||||
|
||||
Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).
|
||||
|
||||
- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, HeunDiscreteScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
negative_prompt = "low quality, blurry, ugly, poor details"
|
||||
pipeline(prompt, negative_prompt=negative_prompt).images[0]
|
||||
```
|
||||
|
||||
## Próximos passos
|
||||
|
||||
Diffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.
|
||||
@@ -489,7 +489,6 @@ class AdaptiveMaskInpaintPipeline(
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -651,7 +650,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -666,7 +665,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -1380,7 +1380,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
flow_model.eval()
|
||||
self.flow_model = flow_model
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -1413,7 +1413,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -1672,7 +1672,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -1687,7 +1687,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
@@ -1699,7 +1699,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -277,7 +277,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -459,7 +459,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -525,7 +525,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -1195,7 +1195,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
# Below are methods copied from StableDiffusionPipeline
|
||||
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -1228,7 +1228,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -1426,7 +1426,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -1441,7 +1441,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
@@ -1453,7 +1453,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
@@ -1534,17 +1534,17 @@ class LLMGroundedDiffusionPipeline(
|
||||
return emb
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.guidance_scale
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.guidance_rescale
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.clip_skip
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
@@ -1552,16 +1552,16 @@ class LLMGroundedDiffusionPipeline(
|
||||
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.do_classifier_free_guidance
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.cross_attention_kwargs
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.num_timesteps
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@@ -503,7 +503,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -518,7 +518,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -532,7 +532,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -721,7 +721,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -941,7 +941,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -121,7 +121,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -136,7 +136,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -196,7 +196,7 @@ def _get_crops_coords_list(num_rows, num_cols, output_width):
|
||||
return result_list
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
@@ -223,7 +223,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -627,7 +627,7 @@ class StableDiffusionXLTilingPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -244,7 +244,7 @@ def _tile2latent_indices(
|
||||
return latent_row_init, latent_row_end, latent_col_init, latent_col_end
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -394,7 +394,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
|
||||
COSINE = "Cosine"
|
||||
GAUSSIAN = "Gaussian"
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -633,7 +633,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -194,7 +194,7 @@ class AnimateDiffControlNetPipeline(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -460,7 +460,7 @@ class AnimateDiffControlNetPipeline(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -180,7 +180,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -194,7 +194,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -311,7 +311,7 @@ class AnimateDiffImgToVideoPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -577,7 +577,7 @@ class AnimateDiffImgToVideoPipeline(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -165,7 +165,7 @@ class AnimateDiffPipelineIpex(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -438,7 +438,7 @@ class AnimateDiffPipelineIpex(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -113,7 +113,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -458,7 +458,7 @@ class KolorsControlNetPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -133,7 +133,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -501,7 +501,7 @@ class KolorsControlNetImg2ImgPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for others.
|
||||
|
||||
@@ -120,7 +120,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -134,7 +134,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -552,7 +552,7 @@ class KolorsControlNetInpaintPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -73,7 +73,7 @@ def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
||||
return blurred_latents
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -379,7 +379,7 @@ class DemoFusionSDXLPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -186,7 +186,7 @@ class FabricPipeline(DiffusionPipeline):
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
|
||||
@@ -1078,7 +1078,7 @@ class LocalAttention:
|
||||
return out
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -1101,7 +1101,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -1125,7 +1125,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -1505,7 +1505,7 @@ class FaithDiffStableDiffusionXLPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -95,7 +95,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -109,7 +109,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -502,7 +502,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
@@ -517,7 +517,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
|
||||
@@ -126,7 +126,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -186,7 +186,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -567,7 +567,7 @@ class FluxKontextPipeline(
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
@@ -582,7 +582,7 @@ class FluxKontextPipeline(
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
|
||||
@@ -89,7 +89,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -103,7 +103,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -86,7 +86,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -100,7 +100,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -640,7 +640,7 @@ class FluxSemanticGuidancePipeline(
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
@@ -655,7 +655,7 @@ class FluxSemanticGuidancePipeline(
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
|
||||
@@ -65,7 +65,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -79,7 +79,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -146,7 +146,7 @@ def get_resize_crop_region_for_grid(src, tgt_size):
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -161,7 +161,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
@@ -177,7 +177,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -512,7 +512,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -527,7 +527,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -80,7 +80,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -458,7 +458,7 @@ class KolorsDifferentialImg2ImgPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
@@ -709,7 +709,7 @@ class KolorsDifferentialImg2ImgPipeline(
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
@@ -94,7 +94,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -243,7 +243,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -257,7 +257,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -669,7 +669,7 @@ class KolorsInpaintPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -213,7 +213,7 @@ class Prompt2PromptPipeline(
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -246,7 +246,7 @@ class Prompt2PromptPipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -430,7 +430,7 @@ class Prompt2PromptPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -445,7 +445,7 @@ class Prompt2PromptPipeline(
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -311,7 +311,7 @@ class SharedAttentionProcessor(AttnProcessor2_0):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -326,7 +326,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -371,7 +371,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -769,7 +769,7 @@ class StyleAlignedSDXLPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -80,7 +80,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -86,7 +86,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -100,7 +100,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -114,7 +114,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -386,7 +386,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -297,7 +297,7 @@ class AAS_XL(AttentionBase):
|
||||
return out
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -439,7 +439,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -453,7 +453,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -696,7 +696,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -931,7 +931,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -137,7 +137,7 @@ def _preprocess_adapter_image(image, height, width):
|
||||
return image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -237,7 +237,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
else 128
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -475,7 +475,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -283,7 +283,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -384,7 +384,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
else 128
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -622,7 +622,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -88,7 +88,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -103,7 +103,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -117,7 +117,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -268,7 +268,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
||||
else:
|
||||
self.watermark = None
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -506,7 +506,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -98,7 +98,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -113,7 +113,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -520,7 +520,7 @@ class StableDiffusionXLPipelineIpex(
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -138,7 +138,7 @@ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -409,7 +409,7 @@ class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -142,7 +142,7 @@ def forward_without_stg(
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -119,7 +119,7 @@ def forward_with_stg(
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -133,7 +133,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -124,7 +124,7 @@ def forward_with_stg(
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -138,7 +138,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -198,7 +198,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -137,7 +137,7 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||
return sigma_schedule
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -97,7 +97,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -124,7 +124,7 @@ def retrieve_latents_fill(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -534,7 +534,7 @@ class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin,
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
@@ -549,7 +549,7 @@ class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin,
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.FluxMixin._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
@@ -1168,12 +1168,12 @@ class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin,
|
||||
generator,
|
||||
)
|
||||
|
||||
mask_imagee = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
masked_imagee = init_image * (1 - mask_imagee)
|
||||
masked_imagee = masked_imagee.to(dtype=self.vae.dtype, device=device)
|
||||
maskkk, masked_image_latentsss = self.prepare_mask_latents_fill(
|
||||
mask_imagee,
|
||||
masked_imagee,
|
||||
mask_image_fill = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
masked_image_fill = init_image * (1 - mask_image_fill)
|
||||
masked_image_fill = masked_image_fill.to(dtype=self.vae.dtype, device=device)
|
||||
mask_fill, masked_latents_fill = self.prepare_mask_latents_fill(
|
||||
mask_image_fill,
|
||||
masked_image_fill,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_images_per_prompt,
|
||||
@@ -1243,7 +1243,7 @@ class FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin,
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
masked_image_latents_fill = torch.cat((masked_image_latentsss, maskkk), dim=-1)
|
||||
masked_image_latents_fill = torch.cat((masked_latents_fill, mask_fill), dim=-1)
|
||||
latent_model_input = torch.cat([latents, masked_image_latents_fill], dim=2)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
|
||||
@@ -234,7 +234,7 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
warnings.warn(
|
||||
"The decode_latents method is deprecated and will be removed in a future version. Please"
|
||||
@@ -248,7 +248,7 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -338,7 +338,7 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
warnings.warn(
|
||||
"The decode_latents method is deprecated and will be removed in a future version. Please"
|
||||
@@ -352,7 +352,7 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -358,7 +358,7 @@ class StableDiffusionReferencePipeline(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -408,7 +408,7 @@ class StableDiffusionReferencePipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
@@ -639,7 +639,7 @@ class StableDiffusionReferencePipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(
|
||||
self, generator: Union[torch.Generator, List[torch.Generator]], eta: float
|
||||
) -> Dict[str, Any]:
|
||||
@@ -789,7 +789,7 @@ class StableDiffusionReferencePipeline(
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
return ref_image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(
|
||||
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
|
||||
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
|
||||
|
||||
@@ -281,7 +281,7 @@ class StableDiffusionRepaintPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -427,7 +427,7 @@ class StableDiffusionRepaintPipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
@@ -438,7 +438,7 @@ class StableDiffusionRepaintPipeline(
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
@@ -456,7 +456,7 @@ class StableDiffusionRepaintPipeline(
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -832,7 +832,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
if "vae_encoder" in self.stages:
|
||||
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(
|
||||
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
|
||||
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
|
||||
|
||||
@@ -915,7 +915,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(
|
||||
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
|
||||
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
|
||||
|
||||
@@ -788,7 +788,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(
|
||||
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
|
||||
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
|
||||
|
||||
@@ -87,7 +87,7 @@ def torch_dfs(model: torch.nn.Module):
|
||||
return result
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -61,7 +61,7 @@ def torch_dfs(model: torch.nn.Module):
|
||||
return result
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -76,7 +76,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
# DreamBooth training example for FLUX.2 [dev]
|
||||
|
||||
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
|
||||
|
||||
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
|
||||
|
||||
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
|
||||
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
|
||||
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_flux.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
|
||||
|
||||
## Memory Optimizations
|
||||
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
|
||||
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
|
||||
### Remote Text Encoder
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
> [!NOTE]
|
||||
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
|
||||
### Latent Caching
|
||||
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
|
||||
### QLoRA: Low Precision Training with Quantization
|
||||
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
|
||||
- **FP8 training** with `torchao`:
|
||||
enable FP8 training by passing `--do_fp8_training`.
|
||||
> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater.
|
||||
> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
|
||||
- **NF4 training** with `bitsandbytes`:
|
||||
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
|
||||
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
|
||||
### Gradient Checkpointing and Accumulation
|
||||
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
|
||||
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
|
||||
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
|
||||
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
|
||||
### 8-bit-Adam Optimizer
|
||||
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
|
||||
Make sure to install `bitsandbytes` if you want to do so.
|
||||
### Image Resolution
|
||||
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
|
||||
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
|
||||
### Precision of saved LoRA layers
|
||||
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
|
||||
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
|
||||
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--use_8bit_adam \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="adamW" \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Prodigy Optimizer
|
||||
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
|
||||
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
|
||||
|
||||
to use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -
|
||||
```bash
|
||||
--optimizer="prodigy"
|
||||
```
|
||||
> [!TIP]
|
||||
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
|
||||
|
||||
To perform DreamBooth with LoRA, run:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="prodigy" \
|
||||
--learning_rate=1. \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- lora_alpha vs. rank:
|
||||
This ratio dictates the LoRA's effective strength:
|
||||
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
|
||||
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
|
||||
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
|
||||
|
||||
## Training Image-to-Image
|
||||
|
||||
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
|
||||
|
||||
**important**
|
||||
|
||||
**Important**
|
||||
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
|
||||
To start, you must have a dataset containing triplets:
|
||||
|
||||
* Condition image - the input image to be transformed.
|
||||
* Target image - the desired output image after transformation.
|
||||
* Instruction - a text prompt describing the transformation from the condition image to the target image.
|
||||
|
||||
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_flux2_img2img.py \
|
||||
--pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
|
||||
--output_dir="flux2-i2i" \
|
||||
--dataset_name="kontext-community/relighting" \
|
||||
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--optimizer="adamw" \
|
||||
--use_8bit_adam \
|
||||
--cache_latents \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=200 \
|
||||
--max_train_steps=1000 \
|
||||
--rank=16\
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
More generally, when performing I2I fine-tuning, we expect you to:
|
||||
|
||||
* Have a dataset `kontext-community/relighting`
|
||||
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
|
||||
|
||||
### Misc notes
|
||||
|
||||
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
|
||||
### Aspect Ratio Bucketing
|
||||
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
|
||||
|
||||
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
|
||||
|
||||
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
|
||||
`
|
||||
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
|
||||
@@ -0,0 +1,262 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "dog"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
|
||||
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
|
||||
|
||||
def test_dreambooth_lora_flux2(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 8
|
||||
--checkpointing_steps=2
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1016,7 +1016,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
|
||||
return new_string[:-nSpace]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -1124,7 +1124,7 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
|
||||
return new_string[:-nSpace]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -1323,7 +1323,7 @@ class AnyTextPipeline(
|
||||
return True
|
||||
return False
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -1356,7 +1356,7 @@ class AnyTextPipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -1610,7 +1610,7 @@ class AnyTextPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -1625,7 +1625,7 @@ class AnyTextPipeline(
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
@@ -1637,7 +1637,7 @@ class AnyTextPipeline(
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -185,7 +185,7 @@ def get_closest_hw(width, height, image_size):
|
||||
return width, height
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -457,7 +457,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -97,7 +97,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -285,7 +285,7 @@ class PromptDiffusionPipeline(
|
||||
)
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -318,7 +318,7 @@ class PromptDiffusionPipeline(
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -525,7 +525,7 @@ class PromptDiffusionPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -540,7 +540,7 @@ class PromptDiffusionPipeline(
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
@@ -552,7 +552,7 @@ class PromptDiffusionPipeline(
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -0,0 +1,475 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
|
||||
|
||||
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
"""
|
||||
# VAE
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
|
||||
--vae_filename "flux2-vae.sft" \
|
||||
--output_path "/raid/yiyi/dummy-flux2-diffusers" \
|
||||
--vae
|
||||
|
||||
# DiT
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
|
||||
--dit_filename flux-dev-dummy.sft \
|
||||
--dit \
|
||||
--output_path .
|
||||
|
||||
# Full pipe
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
|
||||
--dit_filename flux-dev-dummy.sft \
|
||||
--vae_filename "flux2-vae.sft" \
|
||||
--dit --vae --full_pipe \
|
||||
--output_path .
|
||||
"""
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
|
||||
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
|
||||
parser.add_argument("--vae", action="store_true")
|
||||
parser.add_argument("--dit", action="store_true")
|
||||
parser.add_argument("--vae_dtype", type=str, default="fp32")
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--full_pipe", action="store_true")
|
||||
parser.add_argument("--output_path", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
|
||||
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
||||
"encoder.conv_in.bias": "encoder.conv_in.bias",
|
||||
"encoder.conv_out.weight": "encoder.conv_out.weight",
|
||||
"encoder.conv_out.bias": "encoder.conv_out.bias",
|
||||
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
|
||||
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
|
||||
"decoder.conv_in.weight": "decoder.conv_in.weight",
|
||||
"decoder.conv_in.bias": "decoder.conv_in.bias",
|
||||
"decoder.conv_out.weight": "decoder.conv_out.weight",
|
||||
"decoder.conv_out.bias": "decoder.conv_out.bias",
|
||||
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
|
||||
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
|
||||
"quant_conv.weight": "encoder.quant_conv.weight",
|
||||
"quant_conv.bias": "encoder.quant_conv.bias",
|
||||
"post_quant_conv.weight": "decoder.post_quant_conv.weight",
|
||||
"post_quant_conv.bias": "decoder.post_quant_conv.bias",
|
||||
"bn.running_mean": "bn.running_mean",
|
||||
"bn.running_var": "bn.running_var",
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
|
||||
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
|
||||
|
||||
|
||||
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = (
|
||||
ldm_key.replace(mapping["old"], mapping["new"])
|
||||
.replace("norm.weight", "group_norm.weight")
|
||||
.replace("norm.bias", "group_norm.bias")
|
||||
.replace("q.weight", "to_q.weight")
|
||||
.replace("q.bias", "to_q.bias")
|
||||
.replace("k.weight", "to_k.weight")
|
||||
.replace("k.bias", "to_k.bias")
|
||||
.replace("v.weight", "to_v.weight")
|
||||
.replace("v.bias", "to_v.bias")
|
||||
.replace("proj_out.weight", "to_out.0.weight")
|
||||
.replace("proj_out.bias", "to_out.0.bias")
|
||||
)
|
||||
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
shape = new_checkpoint[diffusers_key].shape
|
||||
|
||||
if len(shape) == 3:
|
||||
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
|
||||
elif len(shape) == 4:
|
||||
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
|
||||
|
||||
|
||||
def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
|
||||
new_checkpoint = {}
|
||||
for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
|
||||
if ldm_key not in vae_state_dict:
|
||||
continue
|
||||
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len(config["down_block_types"])
|
||||
down_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
|
||||
)
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
)
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
|
||||
)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
||||
update_vae_attentions_ldm_to_diffusers(
|
||||
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
)
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len(config["up_block_types"])
|
||||
up_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
|
||||
)
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||
]
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
|
||||
)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
||||
update_vae_attentions_ldm_to_diffusers(
|
||||
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Image and text input projections
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
# Timestep and guidance embeddings
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
|
||||
# Modulation parameters
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
# Final output layer
|
||||
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
|
||||
# Handle fused QKV projections separately as we need to break into Q, K, V projections
|
||||
"img_attn.norm.query_norm": "attn.norm_q",
|
||||
"img_attn.norm.key_norm": "attn.norm_k",
|
||||
"img_attn.proj": "attn.to_out.0",
|
||||
"img_mlp.0": "ff.linear_in",
|
||||
"img_mlp.2": "ff.linear_out",
|
||||
"txt_attn.norm.query_norm": "attn.norm_added_q",
|
||||
"txt_attn.norm.key_norm": "attn.norm_added_k",
|
||||
"txt_attn.proj": "attn.to_add_out",
|
||||
"txt_mlp.0": "ff_context.linear_in",
|
||||
"txt_mlp.2": "ff_context.linear_out",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
|
||||
"linear1": "attn.to_qkv_mlp_proj",
|
||||
"linear2": "attn.to_out",
|
||||
"norm.query_norm": "attn.norm_q",
|
||||
"norm.key_norm": "attn.norm_k",
|
||||
}
|
||||
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
|
||||
# diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight
|
||||
if ".weight" not in key:
|
||||
return
|
||||
|
||||
# If adaLN_modulation is in the key, swap scale and shift parameters
|
||||
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
|
||||
if "adaLN_modulation" in key:
|
||||
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
|
||||
# Assume all such keys are in the AdaLayerNorm key map
|
||||
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
|
||||
new_key = ".".join([new_key_without_param_type, param_type])
|
||||
|
||||
swapped_weight = swap_scale_shift(state_dict.pop(key))
|
||||
state_dict[new_key] = swapped_weight
|
||||
return
|
||||
|
||||
|
||||
def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
new_prefix = "transformer_blocks"
|
||||
if "double_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
if "qkv" in within_block_name:
|
||||
fused_qkv_weight = state_dict.pop(key)
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
if "img" in modality_block_name:
|
||||
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.to_q"
|
||||
new_k_name = "attn.to_k"
|
||||
new_v_name = "attn.to_v"
|
||||
elif "txt" in modality_block_name:
|
||||
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.add_q_proj"
|
||||
new_k_name = "attn.add_k_proj"
|
||||
new_v_name = "attn.add_v_proj"
|
||||
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
|
||||
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
|
||||
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
|
||||
state_dict[new_q_key] = to_q_weight
|
||||
state_dict[new_k_key] = to_k_weight
|
||||
state_dict[new_v_key] = to_v_weight
|
||||
else:
|
||||
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
# Mapping:
|
||||
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
|
||||
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
|
||||
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
|
||||
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
|
||||
new_prefix = "single_transformer_blocks"
|
||||
if "single_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaLN_modulation": convert_ada_layer_norm_weights,
|
||||
"double_blocks": convert_flux2_double_stream_blocks,
|
||||
"single_blocks": convert_flux2_single_stream_blocks,
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
if model_type == "test" or model_type == "dummy-flux2":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-flux2",
|
||||
"diffusers_config": {
|
||||
"patch_size": 1,
|
||||
"in_channels": 128,
|
||||
"num_layers": 8,
|
||||
"num_single_layers": 48,
|
||||
"attention_head_dim": 128,
|
||||
"num_attention_heads": 48,
|
||||
"joint_attention_dim": 15360,
|
||||
"timestep_guidance_channels": 256,
|
||||
"mlp_ratio": 3.0,
|
||||
"axes_dims_rope": (32, 32, 32, 32),
|
||||
"rope_theta": 2000,
|
||||
"eps": 1e-6,
|
||||
},
|
||||
}
|
||||
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
|
||||
config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
|
||||
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = Flux2Transformer2DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.vae:
|
||||
original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
|
||||
vae = AutoencoderKLFlux2()
|
||||
converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
|
||||
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
||||
if not args.full_pipe:
|
||||
vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
||||
vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
|
||||
|
||||
if args.dit:
|
||||
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
|
||||
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
|
||||
if not args.full_pipe:
|
||||
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
|
||||
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
|
||||
|
||||
if args.full_pipe:
|
||||
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
|
||||
generate_config = GenerationConfig.from_pretrained(text_encoder_id)
|
||||
generate_config.do_sample = True
|
||||
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
|
||||
)
|
||||
|
||||
pipe = Flux2Pipeline(
|
||||
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
|
||||
)
|
||||
pipe.save_pretrained(args.output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -186,6 +186,7 @@ else:
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLFlux2",
|
||||
"AutoencoderKLHunyuanImage",
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
@@ -215,6 +216,7 @@ else:
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"Flux2Transformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
@@ -271,6 +273,7 @@ else:
|
||||
"WanAnimateTransformer3DModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
"ZImageTransformer2DModel",
|
||||
"attention_backend",
|
||||
]
|
||||
)
|
||||
@@ -457,6 +460,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -647,6 +651,7 @@ else:
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -900,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLFlux2,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
@@ -929,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
@@ -1141,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
@@ -1329,6 +1337,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -111,6 +111,7 @@ def _register_attention_processors_metadata():
|
||||
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
|
||||
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
||||
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
|
||||
|
||||
# AttnProcessor2_0
|
||||
AttentionProcessorRegistry.register(
|
||||
@@ -158,6 +159,14 @@ def _register_attention_processors_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# ZSingleStreamAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=ZSingleStreamAttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
@@ -179,6 +188,7 @@ def _register_transformer_blocks_metadata():
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
|
||||
|
||||
# BasicTransformerBlock
|
||||
TransformerBlockRegistry.register(
|
||||
@@ -312,6 +322,15 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# ZImage
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=ZImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
@@ -338,4 +357,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
|
||||
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
@@ -81,6 +81,7 @@ if is_torch_available():
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
"QwenImageLoraLoaderMixin",
|
||||
"Flux2LoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -113,6 +114,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
Flux2LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
|
||||
@@ -2265,3 +2265,89 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
for sl in range(num_single_layers):
|
||||
single_block_prefix = f"single_blocks.{sl}"
|
||||
attn_prefix = f"single_transformer_blocks.{sl}.attn"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear1.{lora_key}.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear2.{lora_key}.weight"
|
||||
)
|
||||
|
||||
for dl in range(num_double_layers):
|
||||
transformer_block_prefix = f"transformer_blocks.{dl}"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
for attn_type in attn_types:
|
||||
attn_prefix = f"{transformer_block_prefix}.attn"
|
||||
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
|
||||
fused_qkv_weight = original_state_dict.pop(qkv_key)
|
||||
|
||||
if lora_key == "lora_A":
|
||||
diff_attn_proj_keys = (
|
||||
["to_q", "to_k", "to_v"]
|
||||
if attn_type == "img_attn"
|
||||
else ["add_q_proj", "add_k_proj", "add_v_proj"]
|
||||
)
|
||||
for proj_key in diff_attn_proj_keys:
|
||||
converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
|
||||
[fused_qkv_weight]
|
||||
)
|
||||
else:
|
||||
sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
|
||||
if attn_type == "img_attn":
|
||||
converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
else:
|
||||
converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
|
||||
proj_mappings = [
|
||||
("img_attn.proj", "attn.to_out.0"),
|
||||
("txt_attn.proj", "attn.to_add_out"),
|
||||
]
|
||||
for org_proj, diff_proj in proj_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
mlp_mappings = [
|
||||
("img_mlp.0", "ff.linear_in"),
|
||||
("img_mlp.2", "ff.linear_out"),
|
||||
("txt_mlp.0", "ff_context.linear_in"),
|
||||
("txt_mlp.2", "ff_context.linear_out"),
|
||||
]
|
||||
for org_mlp, diff_mlp in mlp_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -45,6 +45,7 @@ from .lora_conversion_utils import (
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_flux2_lora_to_diffusers,
|
||||
_convert_non_diffusers_hidream_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
@@ -5084,6 +5085,209 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if is_ai_toolkit:
|
||||
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -62,6 +62,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from .single_file_utils import (
|
||||
convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
convert_flux2_transformer_checkpoint_to_diffusers,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
@@ -162,6 +163,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Flux2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"net.blocks.0.self_attn.q_proj.weight",
|
||||
"net.pos_embedder.dim_spatial_range",
|
||||
],
|
||||
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -189,6 +190,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
|
||||
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
|
||||
model_type = "flux-2-dev"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
@@ -3647,3 +3652,168 @@ def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Image and text input projections
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
# Timestep and guidance embeddings
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
|
||||
# Modulation parameters
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
# Final output layer
|
||||
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
|
||||
# Handle fused QKV projections separately as we need to break into Q, K, V projections
|
||||
"img_attn.norm.query_norm": "attn.norm_q",
|
||||
"img_attn.norm.key_norm": "attn.norm_k",
|
||||
"img_attn.proj": "attn.to_out.0",
|
||||
"img_mlp.0": "ff.linear_in",
|
||||
"img_mlp.2": "ff.linear_out",
|
||||
"txt_attn.norm.query_norm": "attn.norm_added_q",
|
||||
"txt_attn.norm.key_norm": "attn.norm_added_k",
|
||||
"txt_attn.proj": "attn.to_add_out",
|
||||
"txt_mlp.0": "ff_context.linear_in",
|
||||
"txt_mlp.2": "ff_context.linear_out",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
|
||||
"linear1": "attn.to_qkv_mlp_proj",
|
||||
"linear2": "attn.to_out",
|
||||
"norm.query_norm": "attn.norm_q",
|
||||
"norm.key_norm": "attn.norm_k",
|
||||
}
|
||||
|
||||
def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
# Mapping:
|
||||
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
|
||||
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
|
||||
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
|
||||
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
|
||||
new_prefix = "single_transformer_blocks"
|
||||
if "single_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight
|
||||
if ".weight" not in key:
|
||||
return
|
||||
|
||||
# If adaLN_modulation is in the key, swap scale and shift parameters
|
||||
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
|
||||
if "adaLN_modulation" in key:
|
||||
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
|
||||
# Assume all such keys are in the AdaLayerNorm key map
|
||||
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
|
||||
new_key = ".".join([new_key_without_param_type, param_type])
|
||||
|
||||
swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
|
||||
state_dict[new_key] = swapped_weight
|
||||
|
||||
return
|
||||
|
||||
def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
new_prefix = "transformer_blocks"
|
||||
if "double_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
if "qkv" in within_block_name:
|
||||
fused_qkv_weight = state_dict.pop(key)
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
if "img" in modality_block_name:
|
||||
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.to_q"
|
||||
new_k_name = "attn.to_k"
|
||||
new_v_name = "attn.to_v"
|
||||
elif "txt" in modality_block_name:
|
||||
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.add_q_proj"
|
||||
new_k_name = "attn.add_k_proj"
|
||||
new_v_name = "attn.add_v_proj"
|
||||
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
|
||||
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
|
||||
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
|
||||
state_dict[new_q_key] = to_q_weight
|
||||
state_dict[new_k_key] = to_k_weight
|
||||
state_dict[new_v_key] = to_v_weight
|
||||
else:
|
||||
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaLN_modulation": convert_ada_layer_norm_weights,
|
||||
"double_blocks": convert_flux2_double_stream_blocks,
|
||||
"single_blocks": convert_flux2_single_stream_blocks,
|
||||
}
|
||||
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
update_state_dict(converted_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
@@ -92,6 +93,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
@@ -110,6 +112,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
|
||||
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -140,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLFlux2,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
@@ -190,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -218,6 +223,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
@@ -105,7 +105,7 @@ class AttentionMixin:
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
|
||||
module.fuse_projections()
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
@@ -114,13 +114,14 @@ class AttentionMixin:
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
|
||||
module.unfuse_projections()
|
||||
|
||||
|
||||
class AttentionModuleMixin:
|
||||
_default_processor_cls = None
|
||||
_available_processors = []
|
||||
_supports_qkv_fusion = True
|
||||
fused_projections = False
|
||||
|
||||
def set_processor(self, processor: AttentionProcessor) -> None:
|
||||
@@ -248,6 +249,14 @@ class AttentionModuleMixin:
|
||||
"""
|
||||
Fuse the query, key, and value projections into a single projection for efficiency.
|
||||
"""
|
||||
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
|
||||
# single stream blocks are always fused)
|
||||
if not self._supports_qkv_fusion:
|
||||
logger.debug(
|
||||
f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op."
|
||||
)
|
||||
return
|
||||
|
||||
# Skip if already fused
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
@@ -307,6 +316,11 @@ class AttentionModuleMixin:
|
||||
"""
|
||||
Unfuse the query, key, and value projections back to separate projections.
|
||||
"""
|
||||
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
|
||||
# single stream blocks are always fused)
|
||||
if not self._supports_qkv_fusion:
|
||||
return
|
||||
|
||||
# Skip if not fused
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
@@ -18,7 +18,7 @@ import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -160,16 +160,13 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
# - CP with sage attention, flex, xformers, other missing backends
|
||||
# - Add support for normal and CP training with backends that don't support it yet
|
||||
|
||||
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
||||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
||||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
||||
|
||||
|
||||
class AttentionBackendName(str, Enum):
|
||||
# EAGER = "eager"
|
||||
|
||||
# `flash-attn`
|
||||
FLASH = "flash"
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
@@ -191,6 +188,7 @@ class AttentionBackendName(str, Enum):
|
||||
|
||||
# `sageattention`
|
||||
SAGE = "sage"
|
||||
SAGE_HUB = "sage_hub"
|
||||
SAGE_VARLEN = "sage_varlen"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
|
||||
@@ -264,7 +262,13 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
)
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
),
|
||||
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -420,8 +424,8 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
# TODO: add support Hub variant of varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
@@ -1350,6 +1354,38 @@ def _flash_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -1431,6 +1467,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1444,6 +1481,9 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
@@ -1938,6 +1978,38 @@ def _sage_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -4,6 +4,7 @@ from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_flux2 import AutoencoderKLFlux2
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
@@ -0,0 +1,546 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
)
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
force_upcast (`bool`, *optional*, default to `True`):
|
||||
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
||||
can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
|
||||
can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
||||
mid_block_add_attention (`bool`, *optional*, default to `True`):
|
||||
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
|
||||
mid_block will only have resnet blocks
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
512,
|
||||
),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 32,
|
||||
norm_num_groups: int = 32,
|
||||
sample_size: int = 1024, # YiYi notes: not sure
|
||||
force_upcast: bool = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
mid_block_add_attention: bool = True,
|
||||
batch_norm_eps: float = 1e-4,
|
||||
batch_norm_momentum: float = 0.1,
|
||||
patch_size: Tuple[int, int] = (2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=True,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||
|
||||
self.bn = nn.BatchNorm2d(
|
||||
math.prod(patch_size) * latent_channels,
|
||||
eps=batch_norm_eps,
|
||||
momentum=batch_norm_momentum,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self._tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
enc = self.quant_conv(enc)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
enc = torch.cat(result_rows, dim=2)
|
||||
return enc
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
||||
`tuple` is returned.
|
||||
"""
|
||||
deprecation_message = (
|
||||
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
|
||||
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
|
||||
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
|
||||
)
|
||||
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
moments = torch.cat(result_rows, dim=2)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
# Split z into overlapping 64x64 tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, z.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
if self.config.use_post_quant_conv:
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
dec = torch.cat(result_rows, dim=2)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
@@ -26,6 +26,7 @@ if is_torch_available():
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
@@ -44,3 +45,4 @@ if is_torch_available():
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
from .transformer_wan_animate import WanAnimateTransformer3DModel
|
||||
from .transformer_wan_vace import WanVACETransformer3DModel
|
||||
from .transformer_z_image import ZImageTransformer2DModel
|
||||
|
||||
@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
|
||||
return key_img, value_img
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
|
||||
# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
@@ -137,7 +137,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12660
|
||||
parallel_config=None,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -150,7 +151,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12660
|
||||
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12660
|
||||
# We need to disable the splitting of encoder_hidden_states because
|
||||
# the image_encoder consistently generates 257 tokens for image_embed. This causes
|
||||
# the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
|
||||
# after concatenation—to be indivisible by the number of devices in the CP.
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,908 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = None
|
||||
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = (None,)
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
if attn.fused_projections:
|
||||
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
|
||||
class Flux2SwiGLU(nn.Module):
|
||||
"""
|
||||
Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
|
||||
layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = self.gate_fn(x1) * x2
|
||||
return x
|
||||
|
||||
|
||||
class Flux2FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: float = 3.0,
|
||||
inner_dim: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out or dim
|
||||
|
||||
# Flux2SwiGLU will reduce the dimension by half
|
||||
self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
|
||||
self.act_fn = Flux2SwiGLU()
|
||||
self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_in(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.linear_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Flux2AttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = Flux2AttnProcessor
|
||||
_available_processors = [Flux2AttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
# QK Norm
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2ParallelSelfAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Parallel in (QKV + MLP in) projection
|
||||
hidden_states = attn.to_qkv_mlp_proj(hidden_states)
|
||||
qkv, mlp_hidden_states = torch.split(
|
||||
hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
||||
)
|
||||
|
||||
# Handle the attention logic
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Handle the feedforward (FF) logic
|
||||
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
||||
|
||||
# Concatenate and parallel output projection
|
||||
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
||||
|
||||
This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
|
||||
input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
|
||||
paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
|
||||
"""
|
||||
|
||||
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
||||
# Does not support QKV fusion as the QKV projections are always fused
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
elementwise_affine: bool = True,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_mult_factor: int = 2,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
|
||||
self.mlp_mult_factor = mlp_mult_factor
|
||||
|
||||
# Fused QKV projections + MLP input projection
|
||||
self.to_qkv_mlp_proj = torch.nn.Linear(
|
||||
self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
|
||||
)
|
||||
self.mlp_act_fn = Flux2SwiGLU()
|
||||
|
||||
# QK Norm
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
# Fused attention output projection + MLP output projection
|
||||
self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class Flux2SingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 3.0,
|
||||
eps: float = 1e-6,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
|
||||
# Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
|
||||
# is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
|
||||
# for a visual depiction of this type of transformer block.
|
||||
self.attn = Flux2ParallelSelfAttention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=bias,
|
||||
out_bias=bias,
|
||||
eps=eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
mlp_mult_factor=2,
|
||||
processor=Flux2ParallelSelfAttnProcessor(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
split_hidden_states: bool = False,
|
||||
text_seq_len: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
|
||||
# concatenated
|
||||
if encoder_hidden_states is not None:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
mod_shift, mod_scale, mod_gate = temb_mod_params
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
|
||||
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + mod_gate * attn_output
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
if split_hidden_states:
|
||||
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
||||
return encoder_hidden_states, hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 3.0,
|
||||
eps: float = 1e-6,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
|
||||
self.attn = Flux2Attention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=bias,
|
||||
added_proj_bias=bias,
|
||||
out_bias=bias,
|
||||
eps=eps,
|
||||
processor=Flux2AttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Modulation parameters shape: [1, 1, self.dim]
|
||||
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
|
||||
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
|
||||
|
||||
# Img stream
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
|
||||
|
||||
# Conditioning txt stream
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
|
||||
|
||||
# Attention on concatenated img + txt stream
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the image stream (`hidden_states`).
|
||||
attn_output = gate_msa * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
# Process attention outputs for the text stream (`encoder_hidden_states`).
|
||||
context_attn_output = c_gate_msa * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class Flux2PosEmbed(nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
# Expected ids shape: [S, len(self.axes_dim)]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
is_npu = ids.device.type == "npu"
|
||||
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
|
||||
for i in range(len(self.axes_dim)):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[..., i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
|
||||
return time_guidance_emb
|
||||
|
||||
|
||||
class Flux2Modulation(nn.Module):
|
||||
def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
|
||||
super().__init__()
|
||||
self.mod_param_sets = mod_param_sets
|
||||
|
||||
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
||||
mod = self.act_fn(temb)
|
||||
mod = self.linear(mod)
|
||||
|
||||
if mod.ndim == 2:
|
||||
mod = mod.unsqueeze(1)
|
||||
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
|
||||
# Return tuple of 3-tuples of modulation params shift/scale/gate
|
||||
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
|
||||
|
||||
|
||||
class Flux2Transformer2DModel(
|
||||
ModelMixin,
|
||||
ConfigMixin,
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
FluxTransformer2DLoadersMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux 2.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `1`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `128`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `8`):
|
||||
The number of layers of dual stream DiT blocks to use.
|
||||
num_single_layers (`int`, defaults to `48`):
|
||||
The number of layers of single stream DiT blocks to use.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `48`):
|
||||
The number of attention heads to use.
|
||||
joint_attention_dim (`int`, defaults to `15360`):
|
||||
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
||||
`encoder_hidden_states`).
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The number of dimensions to use for the pooled projection.
|
||||
guidance_embeds (`bool`, defaults to `True`):
|
||||
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
||||
axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 128,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 8,
|
||||
num_single_layers: int = 48,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 48,
|
||||
joint_attention_dim: int = 15360,
|
||||
timestep_guidance_channels: int = 256,
|
||||
mlp_ratio: float = 3.0,
|
||||
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
||||
rope_theta: int = 2000,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Sinusoidal positional embedding for RoPE on image and text tokens
|
||||
self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
# 2. Combined timestep + guidance embedding
|
||||
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
||||
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
||||
)
|
||||
|
||||
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
||||
# Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
|
||||
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
||||
self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
||||
# Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
|
||||
self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
|
||||
|
||||
# 4. Input projections
|
||||
self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
|
||||
|
||||
# 5. Double Stream Transformer Blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Flux2TransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
eps=eps,
|
||||
bias=False,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 6. Single Stream Transformer Blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Flux2SingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
eps=eps,
|
||||
bias=False,
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 7. Output layers
|
||||
self.norm_out = AdaLayerNormContinuous(
|
||||
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
|
||||
)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = self.time_guidance_embed(timestep, guidance)
|
||||
|
||||
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)[0]
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# 3. Calculate RoPE embeddings from image and text tokens
|
||||
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
||||
# text prompts of differents lengths. Is this a use case we want to support?
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if is_torch_npu_available():
|
||||
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
|
||||
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
|
||||
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
|
||||
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(img_ids)
|
||||
text_rotary_emb = self.pos_embed(txt_ids)
|
||||
concat_rotary_emb = (
|
||||
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
)
|
||||
|
||||
# 4. Double Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
double_stream_mod_img,
|
||||
double_stream_mod_txt,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb_mod_params_img=double_stream_mod_img,
|
||||
temb_mod_params_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
# Concatenate text and image streams for single-block inference
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 5. Single Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
single_stream_mod,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
temb_mod_params=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
# Remove text tokens from concatenated stream
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 6. Output layers
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -0,0 +1,660 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import RMSNorm
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
if mid_size is None:
|
||||
mid_size = out_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size,
|
||||
mid_size,
|
||||
bias=True,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
mid_size,
|
||||
out_size,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
weight_dtype = self.mlp[0].weight.dtype
|
||||
if weight_dtype.is_floating_point:
|
||||
t_freq = t_freq.to(weight_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class ZSingleStreamAttnProcessor:
|
||||
"""
|
||||
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
|
||||
original Z-ImageAttention module.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
# Apply Norms
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
output = attn.to_out[0](hidden_states)
|
||||
if len(attn.to_out) > 1: # dropout
|
||||
output = attn.to_out[1](output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // n_heads,
|
||||
heads=n_heads,
|
||||
qk_norm="rms_norm" if qk_norm else None,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=ZSingleStreamAttnProcessor(),
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x) * scale_mlp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 256.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (64, 128, 128),
|
||||
):
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
else:
|
||||
# Ensure freqs_cis are on the same device as ids
|
||||
if self.freqs_cis[0].device != device:
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ZImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=3840,
|
||||
n_layers=30,
|
||||
n_refiner_layers=2,
|
||||
n_heads=30,
|
||||
n_kv_heads=30,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2560,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
axes_lens=[1024, 512, 512],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.all_patch_size = all_patch_size
|
||||
self.all_f_patch_size = all_f_patch_size
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.t_scale = t_scale
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
assert len(all_patch_size) == len(all_f_patch_size)
|
||||
|
||||
all_x_embedder = {}
|
||||
all_final_layer = {}
|
||||
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
|
||||
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
|
||||
|
||||
self.all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
self.all_final_layer = nn.ModuleDict(all_final_layer)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps),
|
||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||
)
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
head_dim = dim // n_heads
|
||||
assert head_dim == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
all_cap_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat(
|
||||
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.layers:
|
||||
unified = self._gradient_checkpointing_func(
|
||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
for layer in self.layers:
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
|
||||
return x, {}
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ...pipelines import FluxPipeline
|
||||
from ...pipelines.flux.pipeline_flux_utils import calculate_shift
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -30,7 +31,7 @@ from .modular_pipeline import FluxModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -90,21 +91,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -53,7 +53,7 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -26,7 +26,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwen_utils.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -40,7 +40,7 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -183,7 +183,7 @@ def get_qwen_prompt_embeds_edit_plus(
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -43,7 +43,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
# configuration of guider is.
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -103,7 +103,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -81,10 +81,7 @@ class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
components.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
components.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
|
||||
@@ -43,7 +43,7 @@ from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -126,7 +126,7 @@ def calculate_dimension_from_latents(
|
||||
return num_frames, height, width
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
|
||||
@@ -102,7 +102,7 @@ def encode_image(
|
||||
return image_embeds.hidden_states[-2]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
|
||||
@@ -129,6 +129,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -395,6 +396,7 @@ else:
|
||||
"WanVACEPipeline",
|
||||
"WanAnimatePipeline",
|
||||
]
|
||||
_import_structure["z_image"] = ["ZImagePipeline"]
|
||||
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
@@ -654,6 +656,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2Pipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
@@ -824,6 +827,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
from .z_image import ZImagePipeline
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -347,7 +347,7 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -152,7 +152,7 @@ class AnimateDiffPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -424,7 +424,7 @@ class AnimateDiffPipeline(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -195,7 +195,7 @@ class AnimateDiffControlNetPipeline(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -468,7 +468,7 @@ class AnimateDiffControlNetPipeline(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -120,7 +120,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
@@ -147,7 +147,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
@@ -325,7 +325,7 @@ class AnimateDiffSDXLPipeline(
|
||||
else 128
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_utils.StableDiffusionXLMixin.encode_prompt with num_images_per_prompt->num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -648,7 +648,7 @@ class AnimateDiffSDXLPipeline(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
@@ -115,7 +115,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
@@ -204,7 +204,7 @@ class AnimateDiffSparseControlNetPipeline(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -471,7 +471,7 @@ class AnimateDiffSparseControlNetPipeline(
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.SDMixin.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user