Compare commits

...

18 Commits

Author SHA1 Message Date
YiYi Xu 463a109619 Merge branch 'main' into custom-rev 2025-11-04 08:13:55 -10:00
Linoy Tsaban dcfb18a2d3 [LoRA] add support for more Qwen LoRAs (#12581)
* fix bug when offload and cache_latents both enabled

* fix
2025-11-04 14:27:25 +02:00
DN6 e90927ae1c update 2025-11-04 17:17:19 +05:30
DN6 5261e07674 update 2025-11-04 17:06:59 +05:30
DN6 5ec47e1a5a update 2025-11-04 13:52:29 +05:30
DN6 cf76e765f6 update 2025-11-04 11:43:03 +05:30
Sayak Paul ac5a1e28fc [docs] sort doc (#12586)
sort doc
2025-11-04 10:26:07 +05:30
Lev Novitskiy 325a95051b Kandinsky 5.0 Docs fixes (#12582)
* add transformer pipeline first version

* updates

* fix 5sec generation

* rewrite Kandinsky5T2VPipeline to diffusers style

* add multiprompt support

* remove prints in pipeline

* add nabla attention

* Wrap Transformer in Diffusers style

* fix license

* fix prompt type

* add gradient checkpointing and peft support

* add usage example

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* remove unused imports

* add 10 second models support

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove no_grad and simplified prompt paddings

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved template to __init__

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* moved sdps inside processor

* remove oneline function

* remove reset_dtype methods

* Transformer: move all methods to forward

* separated prompt encoding

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fixed

* style +copies

* Update src/diffusers/models/transformers/transformer_kandinsky.py

Co-authored-by: Charles <charles@huggingface.co>

* more

* Apply suggestions from code review

* add lora loader doc

* add compiled Nabla Attention

* all needed changes for 10 sec models are added!

* add docs

* Apply style fixes

* update docs

* add kandinsky5 to toctree

* add tests

* fix tests

* Apply style fixes

* update tests

* minor docs refactoring

* refactor Kandinsky 5.0 Vide docs

* Update docs/source/en/_toctree.yml

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Charles <charles@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-03 14:38:07 -10:00
Wang, Yi 1ec28a2c77 ulysses enabling in native attention path (#12563)
* ulysses enabling in native attention path

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* address review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add supports_context_parallel for native attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update templated attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-03 11:48:20 -10:00
YiYi Xu de6173c683 [modular]pass hub_kwargs to load_config (#12577)
pass hub_kwargs to load_config
2025-11-03 09:44:42 -10:00
Sayak Paul 8f80dda193 [tests] add tests for flux modular (t2i, i2i, kontext) (#12566)
* start flux modular tests.

* up

* add kontext

* up

* up

* up

* Update src/diffusers/modular_pipelines/flux/denoise.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* up

* up

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-02 10:51:11 +05:30
YiYi Xu cdbf0ad883 [modular] better warn message (#12573)
better warn message
2025-11-01 18:45:09 -10:00
Dhruv Nair 5e8415a311 Fix custom code loading in Automodel (#12571)
update
2025-11-01 17:04:31 -10:00
Friedrich Schöller 051c8a1c0f Fix Stable Diffusion 3.x pooled prompt embedding with multiple images (#12306) 2025-10-31 10:25:13 -10:00
Dhruv Nair d54622c267 [Modular] Allow custom blocks to be saved to local_dir (#12381)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-10-31 13:47:02 +05:30
Dhruv Nair df8dd77817 [Modular] Fix for custom block kwargs (#12561)
update
2025-10-31 00:14:24 +05:30
Pavle Padjin 9f3c0fdcd8 Avoiding graph break by changing the way we infer dtype in vae.decoder (#12512)
* Changing the way we infer dtype to avoid force evaluation of lazy tensors

* changing way to infer dtype to ensure type consistency

* more robust infering of dtype

* removing the upscale dtype entirely
2025-10-30 08:39:40 +05:30
galbria 84e16575e4 Bria fibo (#12545)
* Bria FIBO pipeline

* style fixs

* fix CR

* Refactor BriaFibo classes and update pipeline parameters

- Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from Flux equivalents.
- Modified the _unpack_latents method in BriaFiboPipeline to improve clarity.
- Increased the default max_sequence_length to 3000 and added a new optional parameter do_patching.
- Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests.

* edit the docs of FIBO

* Remove unused BriaFibo imports and update CPU offload method in BriaFiboPipeline

* Refactor FIBO classes to BriaFibo naming convention

- Updated class names from FIBO to BriaFibo for consistency across the module.
- Modified instances of FIBOEmbedND, FIBOTimesteps, TextProjection, and TimestepProjEmbeddings to reflect the new naming.
- Ensured all references in the BriaFiboTransformer2DModel are updated accordingly.

* Add BriaFiboTransformer2DModel import to transformers module

* Remove unused BriaFibo imports from modular pipelines and add BriaFiboTransformer2DModel and BriaFiboPipeline classes to dummy objects for enhanced compatibility with torch and transformers.

* Update BriaFibo classes with copied documentation and fix import typo in pipeline module

- Added documentation comments indicating the source of copied code in BriaFiboTransformerBlock and _pack_latents methods.
- Corrected the import statement for BriaFiboPipeline in the pipelines module.

* Remove unused BriaFibo imports from __init__.py to streamline modular pipelines.

* Refactor documentation comments in BriaFibo classes to indicate inspiration from existing implementations

- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to reflect that the code is inspired by other modules rather than copied.
- Enhanced clarity on the origins of the methods to maintain proper attribution.

* change Inspired by to Based on

* add reference link and fix trailing whitespace

* Add BriaFiboTransformer2DModel documentation and update comments in BriaFibo classes

- Introduced a new documentation file for BriaFiboTransformer2DModel.
- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to clarify the origins of the code, indicating copied sources for better attribution.

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2025-10-28 16:27:48 +05:30
39 changed files with 2196 additions and 66 deletions
+6 -2
View File
@@ -323,6 +323,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -469,6 +471,8 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -525,8 +529,6 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5
title: Kandinsky 5
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
@@ -634,6 +636,8 @@
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/kandinsky5_video
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx_video
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaFiboTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
## BriaFiboTransformer2DModel
[[autodoc]] BriaFiboTransformer2DModel
+45
View File
@@ -0,0 +1,45 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination - but not control. FIBO changes that.
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaPipeline
[[autodoc]] BriaPipeline
- all
- __call__
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0
# Kandinsky 5.0 Video
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Set attention backend to Flex
) # <--- Sett attention bakend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
+4
View File
@@ -198,6 +198,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
@@ -430,6 +431,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
@@ -901,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
@@ -1103,6 +1106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
@@ -2213,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
state_dict = {convert_key(k): v for k, v in state_dict.items()}
has_default = any("default." in k for k in state_dict)
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
+2 -1
View File
@@ -4940,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
has_default = any("default." in k for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
+2
View File
@@ -84,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -174,6 +175,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
+110 -12
View File
@@ -649,6 +649,86 @@ def _(
# ===== Helper functions to use attention backends with templated CP autograd functions =====
def _native_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
# Native attention does not return_lse
if return_lse:
raise ValueError("Native attention does not support return_lse=True")
# used for backward pass
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.attn_mask = attn_mask
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.enable_gqa = enable_gqa
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
return out
def _native_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
query, key, value = ctx.saved_tensors
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
grad_out_t = grad_out.permute(0, 2, 1, 3)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
)
grad_query = grad_query_t.permute(0, 2, 1, 3)
grad_key = grad_key_t.permute(0, 2, 1, 3)
grad_value = grad_value_t.permute(0, 2, 1, 3)
return grad_query, grad_key, grad_value
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
# forward declaration:
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1523,6 +1603,7 @@ def _native_flex_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
supports_context_parallel=True,
)
def _native_attention(
query: torch.Tensor,
@@ -1538,18 +1619,35 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
else:
out = _templated_context_parallel_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op=_native_attention_forward_op,
backward_op=_native_attention_backward_op,
_parallel_config=_parallel_config,
)
return out
+1 -3
View File
@@ -147,14 +147,13 @@ class AutoModel(ConfigMixin):
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
library = None
orig_class_name = None
@@ -205,7 +204,6 @@ class AutoModel(ConfigMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
else:
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
-3
View File
@@ -286,11 +286,9 @@ class Decoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -18,6 +18,7 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
@@ -0,0 +1,655 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
class BriaFiboAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "BriaFiboAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = BriaFiboAttnProcessor
_available_processors = [BriaFiboAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class BriaFiboEmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class BriaFiboSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = BriaAttnProcessor()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class BriaFiboTextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
@maybe_allow_in_graph
# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
class BriaFiboTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = BriaFiboAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=BriaFiboAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class BriaFiboTimesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class BriaFiboTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = BriaFiboTimesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
...
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
text_encoder_dim: int = 2048,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
if guidance_embeds:
self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BriaFiboTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
BriaFiboSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
caption_projection = [
BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
for i in range(self.config.num_layers + self.config.num_single_layers)
]
self.caption_projection = nn.ModuleList(caption_projection)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
text_encoder_layers: list = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 3:
txt_ids = txt_ids[0]
if len(img_ids.shape) == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
new_text_encoder_layers = []
for i, text_encoder_layer in enumerate(text_encoder_layers):
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
new_text_encoder_layers.append(text_encoder_layer)
text_encoder_layers = new_text_encoder_layers
block_id = 0
for index_block, block in enumerate(self.transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -164,7 +164,11 @@ class AutoOffloadStrategy:
device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
try:
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
except AttributeError:
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
@@ -699,6 +703,8 @@ class ComponentsManager:
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
# TODO: add a warning if mem_get_info isn't available on `device`.
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
remove_hook_from_module(component, recurse=True)
@@ -598,7 +598,7 @@ class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
and getattr(block_state, "image_width", None) is not None
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
@@ -59,7 +59,7 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
),
InputParam(
"guidance",
required=True,
required=False,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
@@ -141,7 +141,7 @@ class FluxKontextLoopDenoiser(ModularPipelineBlocks):
),
InputParam(
"guidance",
required=True,
required=False,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
@@ -95,7 +95,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
default_creation_method="from_config",
),
]
@@ -143,10 +143,6 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux-kontext"
def __init__(self, _auto_resize=True):
self._auto_resize = _auto_resize
super().__init__()
@property
def description(self) -> str:
return (
@@ -167,7 +163,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [InputParam("image")]
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
@property
def intermediate_outputs(self) -> List[OutputParam]:
@@ -195,7 +191,8 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
img = images[0]
image_height, image_width = components.image_processor.get_default_height_width(img)
aspect_ratio = image_width / image_height
if self._auto_resize:
_auto_resize = block_state._auto_resize
if _auto_resize:
# Kontext is trained on specific resolutions, using one of them is recommended
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
@@ -112,6 +112,10 @@ class FluxTextInputStep(ModularPipelineBlocks):
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, -1
)
self.set_block_state(state, block_state)
return components, state
@@ -305,15 +305,15 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"cache_dir",
"force_download",
"local_files_only",
"local_dir",
"proxies",
"resume_download",
"revision",
"subfolder",
"token",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
config = cls.load_config(pretrained_model_name_or_path)
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
@@ -331,11 +331,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
block_kwargs = {
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
}
return block_cls(**block_kwargs)
@@ -2131,8 +2130,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
component_load_kwargs[key] = value["default"]
try:
components_to_register[name] = spec.load(**component_load_kwargs)
except Exception as e:
logger.warning(f"Failed to create component '{name}': {e}")
except Exception:
logger.warning(
f"\nFailed to create component {name}:\n"
f"- Component spec: {spec}\n"
f"- load() called with kwargs: {component_load_kwargs}\n\n"
f"{traceback.format_exc()}"
)
# Register all components at once
self.register_components(**components_to_register)
+2
View File
@@ -128,6 +128,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,838 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Example:
```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
torch.set_grad_enabled(False)
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
pipe = BriaFiboPipeline.from_pretrained(
"briaai/FIBO",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
with torch.inference_mode():
# 1. Create a prompt to generate an initial image
output = vlm_pipe(prompt="a beautiful dog")
json_prompt_generate = output.values["json_prompt"]
# Generate the image from the structured json prompt
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
results_generate.images[0].save("image_generate.png")
```
"""
class BriaFiboPipeline(DiffusionPipeline):
r"""
Args:
transformer (`BriaFiboTransformer2DModel`):
The transformer model for 2D diffusion modeling.
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
Scheduler to be used with `transformer` to denoise the encoded latents.
vae (`AutoencoderKLWan`):
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
text_encoder (`SmolLM3ForCausalLM`):
Text encoder for processing input prompts.
tokenizer (`AutoTokenizer`):
Tokenizer used for processing the input text prompts for the text_encoder.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: BriaFiboTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKLWan,
text_encoder: SmolLM3ForCausalLM,
tokenizer: AutoTokenizer,
):
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 64
def get_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
if not prompt:
raise ValueError("`prompt` must be a non-empty string or list of strings.")
batch_size = len(prompt)
bot_token_id = 128000
text_encoder_device = device if device is not None else torch.device("cpu")
if not isinstance(text_encoder_device, torch.device):
text_encoder_device = torch.device(text_encoder_device)
if all(p == "" for p in prompt):
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
attention_mask = torch.ones_like(input_ids)
else:
tokenized = self.tokenizer(
prompt,
padding="longest",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = tokenized.input_ids.to(text_encoder_device)
attention_mask = tokenized.attention_mask.to(text_encoder_device)
if any(p == "" for p in prompt):
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
input_ids[empty_rows] = bot_token_id
attention_mask[empty_rows] = 1
encoder_outputs = self.text_encoder(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_outputs.hidden_states
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hidden_states = tuple(
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
)
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
return prompt_embeds, hidden_states, attention_mask
@staticmethod
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
batch_size, seq_len, dim = prompt_embeds.shape
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
else:
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if max_tokens < seq_len:
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
if max_tokens > seq_len:
pad_length = max_tokens - seq_len
padding = torch.zeros(
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
mask_padding = torch.zeros(
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
return prompt_embeds, attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 3000,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
guidance_scale (`float`):
Guidance scale for classifier free guidance.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
prompt_attention_mask = None
negative_prompt_attention_mask = None
if prompt_embeds is None:
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
if guidance_scale > 1:
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
negative_prompt = ""
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
# Pad to longest
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if negative_prompt_embeds is not None:
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
)
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
)
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
else:
max_tokens = prompt_embeds.shape[1]
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
negative_prompt_layers = None
dtype = self.text_encoder.dtype
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
return (
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@staticmethod
# Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels)
latents = latents.permute(0, 3, 1, 2)
return latents
@staticmethod
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
latents = latents.permute(0, 2, 3, 1)
latents = latents.reshape(batch_size, height * width, num_channels_latents)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
do_patching=False,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if do_patching:
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
else:
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@staticmethod
def _prepare_attention_mask(attention_mask):
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
# convert to 0 - keep, -inf ignore
attention_matrix = torch.where(
attention_matrix == 1, 0.0, -torch.inf
) # Apply -inf to ignored tokens for nulling softmax score
return attention_matrix
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
Examples:
Returns:
[`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
num_images_per_prompt=num_images_per_prompt,
lora_scale=lora_scale,
)
prompt_batch_size = prompt_embeds.shape[0]
if guidance_scale > 1:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_layers = [
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
]
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
self.transformer.single_transformer_blocks
)
if len(prompt_layers) >= total_num_layers_transformer:
# remove first layers
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
else:
# duplicate last layer
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
if do_patching:
num_channels_latents = int(num_channels_latents / 4)
latents, latent_image_ids = self.prepare_latents(
prompt_batch_size,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
do_patching,
)
latent_attention_mask = torch.ones(
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
)
if guidance_scale > 1:
latent_attention_mask = latent_attention_mask.repeat(2, 1)
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
if self._joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask
# Adapt scheduler to dynamic shifting (resolution dependent)
if do_patching:
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
else:
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
# Init sigmas and timesteps according to shift size
# This changes the scheduler in-place according to the dynamic scheduling
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Support old different diffusers versions
if len(latent_image_ids.shape) == 3:
latent_image_ids = latent_image_ids[0]
if len(text_ids.shape) == 3:
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(
device=latent_model_input.device, dtype=latent_model_input.dtype
)
# This is predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_encoder_layers=prompt_layers,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if guidance_scale > 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
if do_patching:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
else:
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
latents = latents.unsqueeze(dim=2)
latents_device = latents[0].device
latents_dtype = latents[0].dtype
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents_device, latents_dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_device, latents_dtype
)
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
image.append(curr_image)
if len(image) == 1:
image = image[0]
else:
image = np.stack(image, axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return BriaFiboPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if max_sequence_length is not None and max_sequence_length > 3000:
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class BriaFiboPipelineOutput(BaseOutput):
"""
Output class for BriaFibo pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
+15
View File
@@ -588,6 +588,21 @@ class AutoModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BriaFiboTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BriaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -482,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class BriaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+7 -1
View File
@@ -254,6 +254,7 @@ def get_cached_module_file(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
local_dir: Optional[str] = None,
):
"""
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -332,6 +333,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
local_dir=local_dir,
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
@@ -355,6 +357,8 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
local_dir=local_dir,
revision=revision,
token=token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -415,6 +419,7 @@ def get_cached_module_file(
token=token,
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
)
return os.path.join(full_submodule, module_file)
@@ -431,7 +436,7 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
local_dir: Optional[str] = None,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -496,5 +501,6 @@ def get_class_from_dynamic_module(
token=token,
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
)
return get_class_in_module(class_name, final_module)
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import BriaFiboTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}
@property
def input_shape(self):
return (16, 16)
@property
def output_shape(self):
return (256, 48)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 48,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"joint_attention_dim": 64,
"text_encoder_dim": 32,
"pooled_projection_dim": None,
"axes_dims_rope": [0, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
import numpy as np
import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.modular_pipelines import (
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,
FluxModularPipeline,
ModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class FluxModularTests:
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
repo = "hf-internal-testing/tiny-flux-modular"
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_components(torch_dtype=torch_dtype)
return pipeline
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = super().get_pipeline(components_manager, torch_dtype)
# Override `vae_scale_factor` here as currently, `image_processor` is initialized with
# fixed constants instead of
# https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
return pipeline
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs["image"] = image
inputs["strength"] = 0.8
inputs["height"] = 8
inputs["width"] = 8
return inputs
def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
pipes.append(pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
pipeline_class = FluxKontextModularPipeline
pipeline_blocks_class = FluxKontextAutoBlocks
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = PIL.Image.new("RGB", (32, 32), 0)
_ = inputs.pop("strength")
inputs["image"] = image
inputs["height"] = 8
inputs["width"] = 8
inputs["max_area"] = 8 * 8
inputs["_auto_resize"] = False
return inputs
@@ -21,24 +21,12 @@ import numpy as np
import torch
from PIL import Image
from diffusers import (
ClassifierFreeGuidance,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from diffusers.loaders import ModularIPAdapterMixin
from ...models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modular_pipelines_common import (
ModularPipelineTesterMixin,
)
from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
enable_full_determinism()
@@ -0,0 +1,139 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
from diffusers import (
AutoencoderKLWan,
BriaFiboPipeline,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
enable_full_determinism()
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaFiboPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=160,
decoder_base_dim=256,
num_res_blocks=2,
out_channels=12,
patch_size=2,
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
z_dim=16,
)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "{'text': 'A painting of a squirrel eating a burger'}",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"output_type": "np",
}
return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
def test_encode_prompt_works_in_isolation(self):
pass
def test_bria_fibo_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)