Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 463a109619 | |||
| dcfb18a2d3 | |||
| e90927ae1c | |||
| 5261e07674 | |||
| 5ec47e1a5a | |||
| cf76e765f6 | |||
| ac5a1e28fc | |||
| 325a95051b | |||
| 1ec28a2c77 | |||
| de6173c683 | |||
| 8f80dda193 | |||
| cdbf0ad883 | |||
| 5e8415a311 | |||
| 051c8a1c0f | |||
| d54622c267 | |||
| df8dd77817 | |||
| 9f3c0fdcd8 | |||
| 84e16575e4 |
@@ -323,6 +323,8 @@
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/transformer_bria_fibo
|
||||
title: BriaFiboTransformer2DModel
|
||||
- local: api/models/bria_transformer
|
||||
title: BriaTransformer2DModel
|
||||
- local: api/models/chroma_transformer
|
||||
@@ -469,6 +471,8 @@
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/bria_fibo
|
||||
title: Bria Fibo
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogview3
|
||||
@@ -525,8 +529,6 @@
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kandinsky5
|
||||
title: Kandinsky 5
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
@@ -634,6 +636,8 @@
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/kandinsky5_video
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BriaFiboTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
|
||||
|
||||
## BriaFiboTransformer2DModel
|
||||
|
||||
[[autodoc]] BriaFiboTransformer2DModel
|
||||
@@ -0,0 +1,45 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Bria Fibo
|
||||
|
||||
Text-to-image models have mastered imagination - but not control. FIBO changes that.
|
||||
|
||||
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
|
||||
|
||||
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
|
||||
|
||||
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
|
||||
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
|
||||
|
||||
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
|
||||
|
||||
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
|
||||
Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
|
||||
## BriaPipeline
|
||||
|
||||
[[autodoc]] BriaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
+4
-4
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kandinsky 5.0
|
||||
# Kandinsky 5.0 Video
|
||||
|
||||
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
|
||||
|
||||
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
|
||||
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
|
||||
|
||||
pipe.transformer.set_attention_backend(
|
||||
"flex"
|
||||
) # <--- Set attention backend to Flex
|
||||
) # <--- Sett attention bakend to Flex
|
||||
pipe.transformer.compile(
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
dynamic=True
|
||||
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### Diffusion Distilled model
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
|
||||
|
||||
```python
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
|
||||
@@ -198,6 +198,7 @@ else:
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"BriaFiboTransformer2DModel",
|
||||
"BriaTransformer2DModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
@@ -430,6 +431,7 @@ else:
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"BriaFiboPipeline",
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaPipeline",
|
||||
@@ -901,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
CacheMixin,
|
||||
ChromaTransformer2DModel,
|
||||
@@ -1103,6 +1106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
BriaFiboPipeline,
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaPipeline,
|
||||
|
||||
@@ -2213,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
state_dict = {convert_key(k): v for k, v in state_dict.items()}
|
||||
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
|
||||
@@ -4940,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
||||
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
|
||||
@@ -84,6 +84,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
@@ -174,6 +175,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .transformers import (
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
|
||||
@@ -649,6 +649,86 @@ def _(
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
def _native_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
# Native attention does not return_lse
|
||||
if return_lse:
|
||||
raise ValueError("Native attention does not support return_lse=True")
|
||||
|
||||
# used for backward pass
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.attn_mask = attn_mask
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.is_causal = is_causal
|
||||
ctx.scale = scale
|
||||
ctx.enable_gqa = enable_gqa
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _native_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
|
||||
query.requires_grad_(True)
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
grad_value = grad_value_t.permute(0, 2, 1, 3)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
|
||||
# forward declaration:
|
||||
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
@@ -1523,6 +1603,7 @@ def _native_flex_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -1538,18 +1619,35 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op=_native_attention_forward_op,
|
||||
backward_op=_native_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -147,14 +147,13 @@ class AutoModel(ConfigMixin):
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -205,7 +204,6 @@ class AutoModel(ConfigMixin):
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
@@ -286,11 +286,9 @@ class Decoder(nn.Module):
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
|
||||
@@ -18,6 +18,7 @@ if is_torch_available():
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_bria import BriaTransformer2DModel
|
||||
from .transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from .transformer_chroma import ChromaTransformer2DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
|
||||
@@ -0,0 +1,655 @@
|
||||
# Copyright (c) Bria.ai. All rights reserved.
|
||||
#
|
||||
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
|
||||
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
|
||||
#
|
||||
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
|
||||
# indicate if changes were made, and do not use the material for commercial purposes.
|
||||
#
|
||||
# See the license for further details.
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_bria import BriaAttnProcessor
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = None
|
||||
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
|
||||
encoder_query = encoder_key = encoder_value = (None,)
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
|
||||
|
||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||
|
||||
|
||||
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
|
||||
if attn.fused_projections:
|
||||
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
|
||||
class BriaFiboAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "BriaFiboAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
|
||||
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = BriaFiboAttnProcessor
|
||||
_available_processors = [BriaFiboAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
context_pre_only: Optional[bool] = None,
|
||||
pre_only: bool = False,
|
||||
elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class BriaFiboEmbedND(torch.nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class BriaFiboSingleTransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.norm = AdaLayerNormZeroSingle(dim)
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
processor = BriaAttnProcessor()
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm="rms_norm",
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BriaFiboTextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear(caption)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
|
||||
class BriaFiboTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
|
||||
self.attn = BriaFiboAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=BriaFiboAttnProcessor(),
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Attention.
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class BriaFiboTimesteps(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
self.time_theta = time_theta
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
max_period=self.time_theta,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class BriaFiboTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, time_theta):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = BriaFiboTimesteps(
|
||||
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
Parameters:
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
||||
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
||||
...
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = None,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
rope_theta=10000,
|
||||
time_theta=10000,
|
||||
text_encoder_dim: int = 2048,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BriaFiboTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BriaFiboSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
caption_projection = [
|
||||
BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
|
||||
for i in range(self.config.num_layers + self.config.num_single_layers)
|
||||
]
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
text_encoder_layers: list = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
if guidance:
|
||||
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if len(txt_ids.shape) == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if len(img_ids.shape) == 3:
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
new_text_encoder_layers = []
|
||||
for i, text_encoder_layer in enumerate(text_encoder_layers):
|
||||
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
|
||||
new_text_encoder_layers.append(text_encoder_layer)
|
||||
text_encoder_layers = new_text_encoder_layers
|
||||
|
||||
block_id = 0
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
current_text_encoder_layer = text_encoder_layers[block_id]
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
|
||||
)
|
||||
block_id += 1
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
current_text_encoder_layer = text_encoder_layers[block_id]
|
||||
encoder_hidden_states = torch.cat(
|
||||
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
|
||||
)
|
||||
block_id += 1
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -164,7 +164,11 @@ class AutoOffloadStrategy:
|
||||
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
try:
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
|
||||
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
@@ -699,6 +703,8 @@ class ComponentsManager:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
|
||||
|
||||
# TODO: add a warning if mem_get_info isn't available on `device`.
|
||||
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
remove_hook_from_module(component, recurse=True)
|
||||
|
||||
@@ -598,7 +598,7 @@ class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
|
||||
and getattr(block_state, "image_width", None) is not None
|
||||
):
|
||||
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
|
||||
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
|
||||
img_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
None, image_latent_height // 2, image_latent_width // 2, device, dtype
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=True,
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
@@ -141,7 +141,7 @@ class FluxKontextLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=True,
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
|
||||
@@ -95,7 +95,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
@@ -143,10 +143,6 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
def __init__(self, _auto_resize=True):
|
||||
self._auto_resize = _auto_resize
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
@@ -167,7 +163,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("image")]
|
||||
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
@@ -195,7 +191,8 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
img = images[0]
|
||||
image_height, image_width = components.image_processor.get_default_height_width(img)
|
||||
aspect_ratio = image_width / image_height
|
||||
if self._auto_resize:
|
||||
_auto_resize = block_state._auto_resize
|
||||
if _auto_resize:
|
||||
# Kontext is trained on specific resolutions, using one of them is recommended
|
||||
_, image_width, image_height = min(
|
||||
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
||||
|
||||
@@ -112,6 +112,10 @@ class FluxTextInputStep(ModularPipelineBlocks):
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
|
||||
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, -1
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -305,15 +305,15 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"local_dir",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
|
||||
config = cls.load_config(pretrained_model_name_or_path)
|
||||
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
@@ -331,11 +331,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
|
||||
block_kwargs = {
|
||||
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
|
||||
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
|
||||
}
|
||||
|
||||
return block_cls(**block_kwargs)
|
||||
@@ -2131,8 +2130,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
component_load_kwargs[key] = value["default"]
|
||||
try:
|
||||
components_to_register[name] = spec.load(**component_load_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create component '{name}': {e}")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
# Register all components at once
|
||||
self.register_components(**components_to_register)
|
||||
|
||||
@@ -128,6 +128,7 @@ else:
|
||||
"AnimateDiffVideoToVideoControlNetPipeline",
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_bria_fibo import BriaFiboPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,838 @@
|
||||
# Copyright (c) Bria.ai. All rights reserved.
|
||||
#
|
||||
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
|
||||
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
|
||||
#
|
||||
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
|
||||
# indicate if changes were made, and do not use the material for commercial purposes.
|
||||
#
|
||||
# See the license for further details.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
|
||||
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
|
||||
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Example:
|
||||
```python
|
||||
import torch
|
||||
from diffusers import BriaFiboPipeline
|
||||
from diffusers.modular_pipelines import ModularPipeline
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
|
||||
|
||||
pipe = BriaFiboPipeline.from_pretrained(
|
||||
"briaai/FIBO",
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
with torch.inference_mode():
|
||||
# 1. Create a prompt to generate an initial image
|
||||
output = vlm_pipe(prompt="a beautiful dog")
|
||||
json_prompt_generate = output.values["json_prompt"]
|
||||
|
||||
# Generate the image from the structured json prompt
|
||||
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
|
||||
results_generate.images[0].save("image_generate.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BriaFiboPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer (`BriaFiboTransformer2DModel`):
|
||||
The transformer model for 2D diffusion modeling.
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
|
||||
Scheduler to be used with `transformer` to denoise the encoded latents.
|
||||
vae (`AutoencoderKLWan`):
|
||||
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
|
||||
text_encoder (`SmolLM3ForCausalLM`):
|
||||
Text encoder for processing input prompts.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
Tokenizer used for processing the input text prompts for the text_encoder.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: BriaFiboTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKLWan,
|
||||
text_encoder: SmolLM3ForCausalLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
):
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.default_sample_size = 64
|
||||
|
||||
def get_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 2048,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if not prompt:
|
||||
raise ValueError("`prompt` must be a non-empty string or list of strings.")
|
||||
|
||||
batch_size = len(prompt)
|
||||
bot_token_id = 128000
|
||||
|
||||
text_encoder_device = device if device is not None else torch.device("cpu")
|
||||
if not isinstance(text_encoder_device, torch.device):
|
||||
text_encoder_device = torch.device(text_encoder_device)
|
||||
|
||||
if all(p == "" for p in prompt):
|
||||
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
tokenized = self.tokenizer(
|
||||
prompt,
|
||||
padding="longest",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = tokenized.input_ids.to(text_encoder_device)
|
||||
attention_mask = tokenized.attention_mask.to(text_encoder_device)
|
||||
|
||||
if any(p == "" for p in prompt):
|
||||
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
|
||||
input_ids[empty_rows] = bot_token_id
|
||||
attention_mask[empty_rows] = 1
|
||||
|
||||
encoder_outputs = self.text_encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_outputs.hidden_states
|
||||
|
||||
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
|
||||
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
hidden_states = tuple(
|
||||
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
|
||||
)
|
||||
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
|
||||
|
||||
return prompt_embeds, hidden_states, attention_mask
|
||||
|
||||
@staticmethod
|
||||
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
|
||||
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
|
||||
batch_size, seq_len, dim = prompt_embeds.shape
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
|
||||
if max_tokens < seq_len:
|
||||
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
|
||||
|
||||
if max_tokens > seq_len:
|
||||
pad_length = max_tokens - seq_len
|
||||
padding = torch.zeros(
|
||||
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
|
||||
|
||||
mask_padding = torch.zeros(
|
||||
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
|
||||
|
||||
return prompt_embeds, attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 3000,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
guidance_scale (`float`):
|
||||
Guidance scale for classifier free guidance.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
prompt_attention_mask = None
|
||||
negative_prompt_attention_mask = None
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
|
||||
|
||||
if guidance_scale > 1:
|
||||
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
|
||||
negative_prompt = ""
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
|
||||
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
# Pad to longest
|
||||
if prompt_attention_mask is not None:
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
if negative_prompt_attention_mask is not None:
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
|
||||
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
|
||||
)
|
||||
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
|
||||
|
||||
prompt_embeds, prompt_attention_mask = self.pad_embedding(
|
||||
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
|
||||
)
|
||||
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
|
||||
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
|
||||
else:
|
||||
max_tokens = prompt_embeds.shape[1]
|
||||
prompt_embeds, prompt_attention_mask = self.pad_embedding(
|
||||
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
|
||||
)
|
||||
negative_prompt_layers = None
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt_layers,
|
||||
negative_prompt_layers,
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@staticmethod
|
||||
# Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels)
|
||||
latents = latents.permute(0, 3, 1, 2)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.permute(0, 2, 3, 1)
|
||||
latents = latents.reshape(batch_size, height * width, num_channels_latents)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
do_patching=False,
|
||||
):
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if do_patching:
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
else:
|
||||
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@staticmethod
|
||||
def _prepare_attention_mask(attention_mask):
|
||||
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
|
||||
|
||||
# convert to 0 - keep, -inf ignore
|
||||
attention_matrix = torch.where(
|
||||
attention_matrix == 1, 0.0, -torch.inf
|
||||
) # Apply -inf to ignored tokens for nulling softmax score
|
||||
return attention_matrix
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 3000,
|
||||
do_patching=False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
|
||||
do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
|
||||
Examples:
|
||||
Returns:
|
||||
[`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt_layers,
|
||||
negative_prompt_layers,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if guidance_scale > 1:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_layers = [
|
||||
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
|
||||
]
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
|
||||
self.transformer.single_transformer_blocks
|
||||
)
|
||||
if len(prompt_layers) >= total_num_layers_transformer:
|
||||
# remove first layers
|
||||
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
|
||||
else:
|
||||
# duplicate last layer
|
||||
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
|
||||
|
||||
# 5. Prepare latent variables
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
if do_patching:
|
||||
num_channels_latents = int(num_channels_latents / 4)
|
||||
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
prompt_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
do_patching,
|
||||
)
|
||||
|
||||
latent_attention_mask = torch.ones(
|
||||
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
|
||||
)
|
||||
if guidance_scale > 1:
|
||||
latent_attention_mask = latent_attention_mask.repeat(2, 1)
|
||||
|
||||
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
|
||||
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
|
||||
|
||||
if self._joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
self._joint_attention_kwargs["attention_mask"] = attention_mask
|
||||
|
||||
# Adapt scheduler to dynamic shifting (resolution dependent)
|
||||
|
||||
if do_patching:
|
||||
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
|
||||
else:
|
||||
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
|
||||
mu = calculate_shift(
|
||||
seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
|
||||
# Init sigmas and timesteps according to shift size
|
||||
# This changes the scheduler in-place according to the dynamic scheduling
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps=num_inference_steps,
|
||||
device=device,
|
||||
timesteps=None,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# Support old different diffusers versions
|
||||
if len(latent_image_ids.shape) == 3:
|
||||
latent_image_ids = latent_image_ids[0]
|
||||
|
||||
if len(text_ids.shape) == 3:
|
||||
text_ids = text_ids[0]
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(
|
||||
device=latent_model_input.device, dtype=latent_model_input.dtype
|
||||
)
|
||||
|
||||
# This is predicts "v" from flow-matching or eps from diffusion
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
text_encoder_layers=prompt_layers,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if guidance_scale > 1:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
if do_patching:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
else:
|
||||
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
|
||||
|
||||
latents = latents.unsqueeze(dim=2)
|
||||
latents_device = latents[0].device
|
||||
latents_dtype = latents[0].dtype
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents_device, latents_dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents_device, latents_dtype
|
||||
)
|
||||
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
|
||||
latents_scaled = torch.cat(latents_scaled, dim=0)
|
||||
image = []
|
||||
for scaled_latent in latents_scaled:
|
||||
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
|
||||
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
|
||||
image.append(curr_image)
|
||||
if len(image) == 1:
|
||||
image = image[0]
|
||||
else:
|
||||
image = np.stack(image, axis=0)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return BriaFiboPipelineOutput(images=image)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 3000:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class BriaFiboPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for BriaFibo pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline(
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
+1
-1
@@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -588,6 +588,21 @@ class AutoModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BriaFiboTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BriaTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -482,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class BriaFiboPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class BriaPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -254,6 +254,7 @@ def get_cached_module_file(
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
local_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
||||
@@ -332,6 +333,7 @@ def get_cached_module_file(
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
local_dir=local_dir,
|
||||
)
|
||||
submodule = "git"
|
||||
module_file = pretrained_model_name_or_path + ".py"
|
||||
@@ -355,6 +357,8 @@ def get_cached_module_file(
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
local_dir=local_dir,
|
||||
revision=revision,
|
||||
token=token,
|
||||
)
|
||||
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
|
||||
@@ -415,6 +419,7 @@ def get_cached_module_file(
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
local_dir=local_dir,
|
||||
)
|
||||
return os.path.join(full_submodule, module_file)
|
||||
|
||||
@@ -431,7 +436,7 @@ def get_class_from_dynamic_module(
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
**kwargs,
|
||||
local_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Extracts a class from a module file, present in the local folder or repository of a model.
|
||||
@@ -496,5 +501,6 @@ def get_class_from_dynamic_module(
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
local_dir=local_dir,
|
||||
)
|
||||
return get_class_in_module(class_name, final_module)
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BriaFiboTransformer2DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaFiboTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_latent_channels = 48
|
||||
num_image_channels = 3
|
||||
height = width = 16
|
||||
sequence_length = 32
|
||||
embedding_dim = 64
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (256, 48)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 48,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 64,
|
||||
"text_encoder_dim": 32,
|
||||
"pooled_projection_dim": None,
|
||||
"axes_dims_rope": [0, 4, 4],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"BriaFiboTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@@ -0,0 +1,130 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.modular_pipelines import (
|
||||
FluxAutoBlocks,
|
||||
FluxKontextAutoBlocks,
|
||||
FluxKontextModularPipeline,
|
||||
FluxModularPipeline,
|
||||
ModularPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import floats_tensor, torch_device
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
class FluxModularTests:
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-modular"
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
|
||||
class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
|
||||
class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = super().get_pipeline(components_manager, torch_dtype)
|
||||
# Override `vae_scale_factor` here as currently, `image_processor` is initialized with
|
||||
# fixed constants instead of
|
||||
# https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
|
||||
pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
inputs["image"] = image
|
||||
inputs["strength"] = 0.8
|
||||
inputs["height"] = 8
|
||||
inputs["width"] = 8
|
||||
return inputs
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
|
||||
class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = PIL.Image.new("RGB", (32, 32), 0)
|
||||
_ = inputs.pop("strength")
|
||||
inputs["image"] = image
|
||||
inputs["height"] = 8
|
||||
inputs["width"] = 8
|
||||
inputs["max_area"] = 8 * 8
|
||||
inputs["_auto_resize"] = False
|
||||
return inputs
|
||||
+4
-16
@@ -21,24 +21,12 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
BriaFiboPipeline,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = BriaFiboPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = False
|
||||
test_group_offloading = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = BriaFiboTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=16,
|
||||
num_layers=1,
|
||||
num_single_layers=1,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=64,
|
||||
text_encoder_dim=32,
|
||||
pooled_projection_dim=None,
|
||||
axes_dims_rope=[0, 4, 4],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=160,
|
||||
decoder_base_dim=256,
|
||||
num_res_blocks=2,
|
||||
out_channels=12,
|
||||
patch_size=2,
|
||||
scale_factor_spatial=16,
|
||||
scale_factor_temporal=4,
|
||||
temperal_downsample=[False, True, True],
|
||||
z_dim=16,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "{'text': 'A painting of a squirrel eating a burger'}",
|
||||
"negative_prompt": "bad, ugly",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@unittest.skip(reason="will not be supported due to dim-fusion")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
|
||||
def test_bria_fibo_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components())
|
||||
pipe = pipe.to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt"] = "a different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
assert max_diff > 1e-6
|
||||
|
||||
def test_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components())
|
||||
pipe = pipe.to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height
|
||||
expected_width = width
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
assert (output_height, output_width) == (expected_height, expected_width)
|
||||
Reference in New Issue
Block a user