Compare commits

..

9 Commits

Author SHA1 Message Date
Dhruv Nair 6dda0ecf67 Merge branch 'main' into sf-modular 2025-10-28 23:03:45 +05:30
galbria 84e16575e4 Bria fibo (#12545)
* Bria FIBO pipeline

* style fixs

* fix CR

* Refactor BriaFibo classes and update pipeline parameters

- Updated BriaFiboAttnProcessor and BriaFiboAttention classes to reflect changes from Flux equivalents.
- Modified the _unpack_latents method in BriaFiboPipeline to improve clarity.
- Increased the default max_sequence_length to 3000 and added a new optional parameter do_patching.
- Cleaned up test_pipeline_bria_fibo.py by removing unused imports and skipping unsupported tests.

* edit the docs of FIBO

* Remove unused BriaFibo imports and update CPU offload method in BriaFiboPipeline

* Refactor FIBO classes to BriaFibo naming convention

- Updated class names from FIBO to BriaFibo for consistency across the module.
- Modified instances of FIBOEmbedND, FIBOTimesteps, TextProjection, and TimestepProjEmbeddings to reflect the new naming.
- Ensured all references in the BriaFiboTransformer2DModel are updated accordingly.

* Add BriaFiboTransformer2DModel import to transformers module

* Remove unused BriaFibo imports from modular pipelines and add BriaFiboTransformer2DModel and BriaFiboPipeline classes to dummy objects for enhanced compatibility with torch and transformers.

* Update BriaFibo classes with copied documentation and fix import typo in pipeline module

- Added documentation comments indicating the source of copied code in BriaFiboTransformerBlock and _pack_latents methods.
- Corrected the import statement for BriaFiboPipeline in the pipelines module.

* Remove unused BriaFibo imports from __init__.py to streamline modular pipelines.

* Refactor documentation comments in BriaFibo classes to indicate inspiration from existing implementations

- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to reflect that the code is inspired by other modules rather than copied.
- Enhanced clarity on the origins of the methods to maintain proper attribution.

* change Inspired by to Based on

* add reference link and fix trailing whitespace

* Add BriaFiboTransformer2DModel documentation and update comments in BriaFibo classes

- Introduced a new documentation file for BriaFiboTransformer2DModel.
- Updated comments in BriaFiboAttnProcessor, BriaFiboAttention, and BriaFiboPipeline to clarify the origins of the code, indicating copied sources for better attribution.

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2025-10-28 16:27:48 +05:30
Dhruv Nair fb1766ee4f Merge branch 'sf-modular' of https://github.com/huggingface/diffusers into sf-modular 2025-10-23 19:00:13 +02:00
Dhruv Nair 8733fef39d update 2025-10-23 18:56:14 +02:00
github-actions[bot] a707e314ad Apply style fixes 2025-10-20 09:44:50 +00:00
DN6 c1c0e9a481 update 2025-09-29 12:35:55 +05:30
DN6 0a7bde9200 update 2025-09-24 16:33:09 +05:30
DN6 af48d815d8 update 2025-09-24 16:31:07 +05:30
DN6 bea02ccba3 update 2025-09-18 23:31:07 +05:30
27 changed files with 2248 additions and 56 deletions
+4
View File
@@ -323,6 +323,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
title: BriaFiboTransformer2DModel
- local: api/models/bria_transformer
title: BriaTransformer2DModel
- local: api/models/chroma_transformer
@@ -469,6 +471,8 @@
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BriaFiboTransformer2DModel
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
## BriaFiboTransformer2DModel
[[autodoc]] BriaFiboTransformer2DModel
+45
View File
@@ -0,0 +1,45 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo
Text-to-image models have mastered imagination - but not control. FIBO changes that.
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaPipeline
[[autodoc]] BriaPipeline
- all
- __call__
+1 -1
View File
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
```py
guider_spec = t2i_pipeline.get_component_spec("guider")
guider_spec.default_creation_method="from_pretrained"
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
guider_spec.subfolder="pag_guider"
pag_guider = guider_spec.load()
t2i_pipeline.update_components(guider=pag_guider)
@@ -313,14 +313,14 @@ unet_spec
ComponentSpec(
name='unet',
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
repo='RunDiffusion/Juggernaut-XL-v9',
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
subfolder='unet',
variant='fp16',
default_creation_method='from_pretrained'
)
# modify to load from a different repository
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
# load component with modified spec
unet = unet_spec.load(torch_dtype=torch.float16)
+1 -1
View File
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
```py
guider_spec = t2i_pipeline.get_component_spec("guider")
guider_spec.default_creation_method="from_pretrained"
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
guider_spec.subfolder="pag_guider"
pag_guider = guider_spec.load()
t2i_pipeline.update_components(guider=pag_guider)
@@ -313,14 +313,14 @@ unet_spec
ComponentSpec(
name='unet',
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
repo='RunDiffusion/Juggernaut-XL-v9',
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
subfolder='unet',
variant='fp16',
default_creation_method='from_pretrained'
)
# 修改以从不同的仓库加载
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
# 使用修改后的规范加载组件
unet = unet_spec.load(torch_dtype=torch.float16)
+4
View File
@@ -198,6 +198,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
@@ -430,6 +431,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
@@ -901,6 +903,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
@@ -1103,6 +1106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
+8 -1
View File
@@ -387,6 +387,14 @@ def is_valid_url(url):
return False
def _is_single_file_path_or_url(pretrained_model_name_or_path):
if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path):
return False
repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
return bool(repo_id and weight_name)
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
if not is_valid_url(pretrained_model_name_or_path):
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
@@ -398,7 +406,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
match = re.match(pattern, pretrained_model_name_or_path)
if not match:
logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
return repo_id, weights_name
repo_id = f"{match.group(1)}/{match.group(2)}"
+2
View File
@@ -84,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -174,6 +175,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -18,6 +18,7 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
@@ -0,0 +1,655 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
class BriaFiboAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "BriaFiboAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = BriaFiboAttnProcessor
_available_processors = [BriaFiboAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class BriaFiboEmbedND(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class BriaFiboSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = BriaAttnProcessor()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class BriaFiboTextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
@maybe_allow_in_graph
# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
class BriaFiboTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = BriaFiboAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=BriaFiboAttnProcessor(),
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class BriaFiboTimesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class BriaFiboTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = BriaFiboTimesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
...
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
rope_theta=10000,
time_theta=10000,
text_encoder_dim: int = 2048,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
if guidance_embeds:
self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BriaFiboTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
BriaFiboSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
caption_projection = [
BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
for i in range(self.config.num_layers + self.config.num_single_layers)
]
self.caption_projection = nn.ModuleList(caption_projection)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
text_encoder_layers: list = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 3:
txt_ids = txt_ids[0]
if len(img_ids.shape) == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
new_text_encoder_layers = []
for i, text_encoder_layer in enumerate(text_encoder_layers):
text_encoder_layer = self.caption_projection[i](text_encoder_layer)
new_text_encoder_layers.append(text_encoder_layer)
text_encoder_layers = new_text_encoder_layers
block_id = 0
for index_block, block in enumerate(self.transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
current_text_encoder_layer = text_encoder_layers[block_id]
encoder_hidden_states = torch.cat(
[encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
)
block_id += 1
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
joint_attention_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -217,11 +217,6 @@ MELLON_OUTPUT_PARAMS = {
"display": "output",
"type": "controlnet",
},
"doc": {
"label": "Doc",
"display": "output",
"type": "string",
},
}
@@ -702,3 +697,67 @@ class MellonNodeConfig(PushToHubMixin):
blocks_names=blocks_names,
node_type=node_type,
)
# Minimal modular registry for Mellon node configs
class ModularMellonNodeRegistry:
"""Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""
def __init__(self):
self._registry = {}
self._initialized = False
def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
if not self._initialized:
_initialize_registry(self)
self._registry[pipeline_cls] = node_params
def get(self, pipeline_cls: type) -> MellonNodeConfig:
if not self._initialized:
_initialize_registry(self)
return self._registry.get(pipeline_cls, None)
def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
if not self._initialized:
_initialize_registry(self)
return self._registry
def _register_preset_node_types(
pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
):
"""Register all node-type presets for a given pipeline class from a params map."""
node_configs = {}
for node_type, spec in params_map.items():
node_config = MellonNodeConfig(
inputs=spec.get("inputs", []),
model_inputs=spec.get("model_inputs", []),
outputs=spec.get("outputs", []),
blocks_names=spec.get("block_names", []),
node_type=node_type,
)
node_configs[node_type] = node_config
registry.register(pipeline_cls, node_configs)
def _initialize_registry(registry: ModularMellonNodeRegistry):
"""Initialize the registry and register all available pipeline configs."""
print("Initializing registry")
registry._initialized = True
try:
from .qwenimage.modular_pipeline import QwenImageModularPipeline
from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP
_register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
except Exception:
raise Exception("Failed to register QwenImageModularPipeline")
try:
from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP
_register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
except Exception:
raise Exception("Failed to register StableDiffusionXLModularPipeline")
@@ -361,7 +361,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
collection: Optional[str] = None,
) -> "ModularPipeline":
"""
create a ModularPipeline, optionally accept modular_repo to load from hub.
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
"""
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
diffusers_module = importlib.import_module("diffusers")
@@ -1562,7 +1562,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
library, class_name = value
component_spec_dict = {
"repo": pretrained_model_name_or_path,
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"subfolder": name,
"type_hint": (library, class_name),
}
@@ -1619,8 +1619,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
Path to a pretrained pipeline configuration. It will first try to load config from
`modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
during initialization.
non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
will be set to None during initialization.
trust_remote_code (`bool`, optional):
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
@@ -1790,7 +1790,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
library, class_name = None, None
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
# e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
# "type_hint": ("diffusers", "UNet2DConditionModel"),
# "subfolder": "unet",
# "variant": None,
@@ -2094,8 +2094,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
`variant`, `revision`, etc.
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
"""
if names is None:
@@ -2354,10 +2356,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- "type_hint": Tuple[str, str]
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
- All loading fields defined by `component_spec.loading_fields()`, typically:
- "repo": Optional[str]
The model repository (e.g., "stabilityai/stable-diffusion-xl").
- "pretrained_model_name_or_path": Optional[str]
The model pretrained_model_name_or_pathsitory (e.g., "stabilityai/stable-diffusion-xl").
- "subfolder": Optional[str]
A subfolder within the repo where this component lives.
A subfolder within the pretrained_model_name_or_path where this component lives.
- "variant": Optional[str]
An optional variant identifier for the model.
- "revision": Optional[str]
@@ -2374,12 +2376,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Example:
>>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
UNet2DConditionModel >>> spec = ComponentSpec(
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
subfolder="subfolder", ... variant=None, ... revision=None, ...
default_creation_method="from_pretrained",
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
variant=None, ... revision=None, ... default_creation_method="from_pretrained",
... ) >>> ModularPipeline._component_spec_to_dict(spec) {
"type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
"variant": None, "revision": None,
"type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
"subfolder": "subfolder", "variant": None, "revision": None,
"type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
"subfolder": "subfolder", "variant": None, "revision": None,
}
"""
if component_spec.default_creation_method != "from_pretrained":
@@ -2408,10 +2412,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- "type_hint": Tuple[str, str]
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
- All loading fields defined by `component_spec.loading_fields()`, typically:
- "repo": Optional[str]
- "pretrained_model_name_or_path": Optional[str]
The model repository (e.g., "stabilityai/stable-diffusion-xl").
- "subfolder": Optional[str]
A subfolder within the repo where this component lives.
A subfolder within the pretrained_model_name_or_path where this component lives.
- "variant": Optional[str]
An optional variant identifier for the model.
- "revision": Optional[str]
@@ -2428,11 +2432,20 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
ComponentSpec: A reconstructed ComponentSpec object.
Example:
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
"stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
} >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
ComponentSpec(
name="unet", type_hint=UNet2DConditionModel, config=None,
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
revision=None, default_creation_method="from_pretrained"
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
ComponentSpec(
name="unet", type_hint=UNet2DConditionModel, config=None,
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
revision=None, default_creation_method="from_pretrained"
)
"""
# make a shallow copy so we can pop() safely
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import is_torch_available, logging
@@ -80,10 +81,10 @@ class ComponentSpec:
type_hint: Type of the component (e.g. UNet2DConditionModel)
description: Optional description of the component
config: Optional config dict for __init__ creation
repo: Optional repo path for from_pretrained creation
subfolder: Optional subfolder in repo
variant: Optional variant in repo
revision: Optional revision in repo
pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
subfolder: Optional subfolder in pretrained_model_name_or_path
variant: Optional variant in pretrained_model_name_or_path
revision: Optional revision in pretrained_model_name_or_path
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
"""
@@ -91,13 +92,20 @@ class ComponentSpec:
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
# Deprecated
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
def __post_init__(self):
repo_value = self.repo
if repo_value is not None and self.pretrained_model_name_or_path is None:
object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
@@ -182,8 +190,8 @@ class ComponentSpec:
@property
def load_id(self) -> str:
"""
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
segments).
Unique identifier for this spec's pretrained load, composed of
pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
"""
if self.default_creation_method == "from_config":
return "null"
@@ -197,12 +205,13 @@ class ComponentSpec:
Decode a load_id string back into a dictionary of loading fields and values.
Args:
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
where None values are represented as "null"
Returns:
Dict mapping loading field names to their values. e.g. {
"repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
"pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
"revision": "revision"
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
component not created with `load` method).
"""
@@ -259,34 +268,45 @@ class ComponentSpec:
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
# select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
if pretrained_model_name_or_path is None:
raise ValueError(
"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
"`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
)
is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
if is_single_file and self.type_hint is None:
raise ValueError(
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
)
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
# update type_hint if AutoModel load successfully
self.type_hint = component.__class__
else:
# determine load method
load_method = (
getattr(self.type_hint, "from_single_file")
if is_single_file
else getattr(self.type_hint, "from_pretrained")
)
try:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} using load method: {e}")
self.repo = repo
self.pretrained_model_name_or_path = pretrained_model_name_or_path
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
@@ -0,0 +1,95 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mellon nodes
QwenImage_NODE_TYPES_PARAMS_MAP = {
"controlnet": {
"inputs": [
"control_image",
"controlnet_conditioning_scale",
"control_guidance_start",
"control_guidance_end",
"height",
"width",
],
"model_inputs": [
"controlnet",
"vae",
],
"outputs": [
"controlnet_out",
],
"block_names": ["controlnet_vae_encoder"],
},
"denoise": {
"inputs": [
"embeddings",
"width",
"height",
"seed",
"num_inference_steps",
"guidance_scale",
"image_latents",
"strength",
"controlnet",
],
"model_inputs": [
"unet",
"guider",
"scheduler",
],
"outputs": [
"latents",
"latents_preview",
],
"block_names": ["denoise"],
},
"vae_encoder": {
"inputs": [
"image",
"width",
"height",
],
"model_inputs": [
"vae",
],
"outputs": [
"image_latents",
],
},
"text_encoder": {
"inputs": [
"prompt",
"negative_prompt",
],
"model_inputs": [
"text_encoders",
],
"outputs": [
"embeddings",
],
},
"decoder": {
"inputs": [
"latents",
],
"model_inputs": [
"vae",
],
"outputs": [
"images",
],
},
}
@@ -0,0 +1,99 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
SDXL_NODE_TYPES_PARAMS_MAP = {
"controlnet": {
"inputs": [
"control_image",
"controlnet_conditioning_scale",
"control_guidance_start",
"control_guidance_end",
"height",
"width",
],
"model_inputs": [
"controlnet",
],
"outputs": [
"controlnet_out",
],
"block_names": [None],
},
"denoise": {
"inputs": [
"embeddings",
"width",
"height",
"seed",
"num_inference_steps",
"guidance_scale",
"image_latents",
"strength",
# custom adapters coming in as inputs
"controlnet",
# ip_adapter is optional and custom; include if available
"ip_adapter",
],
"model_inputs": [
"unet",
"guider",
"scheduler",
],
"outputs": [
"latents",
"latents_preview",
],
"block_names": ["denoise"],
},
"vae_encoder": {
"inputs": [
"image",
"width",
"height",
],
"model_inputs": [
"vae",
],
"outputs": [
"image_latents",
],
"block_names": ["vae_encoder"],
},
"text_encoder": {
"inputs": [
"prompt",
"negative_prompt",
],
"model_inputs": [
"text_encoders",
],
"outputs": [
"embeddings",
],
"block_names": ["text_encoder"],
},
"decoder": {
"inputs": [
"latents",
],
"model_inputs": [
"vae",
],
"outputs": [
"images",
],
"block_names": ["decode"],
},
}
+2
View File
@@ -128,6 +128,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -562,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,838 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Example:
```python
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
torch.set_grad_enabled(False)
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
pipe = BriaFiboPipeline.from_pretrained(
"briaai/FIBO",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
with torch.inference_mode():
# 1. Create a prompt to generate an initial image
output = vlm_pipe(prompt="a beautiful dog")
json_prompt_generate = output.values["json_prompt"]
# Generate the image from the structured json prompt
results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
results_generate.images[0].save("image_generate.png")
```
"""
class BriaFiboPipeline(DiffusionPipeline):
r"""
Args:
transformer (`BriaFiboTransformer2DModel`):
The transformer model for 2D diffusion modeling.
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
Scheduler to be used with `transformer` to denoise the encoded latents.
vae (`AutoencoderKLWan`):
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
text_encoder (`SmolLM3ForCausalLM`):
Text encoder for processing input prompts.
tokenizer (`AutoTokenizer`):
Tokenizer used for processing the input text prompts for the text_encoder.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: BriaFiboTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKLWan,
text_encoder: SmolLM3ForCausalLM,
tokenizer: AutoTokenizer,
):
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 64
def get_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
if not prompt:
raise ValueError("`prompt` must be a non-empty string or list of strings.")
batch_size = len(prompt)
bot_token_id = 128000
text_encoder_device = device if device is not None else torch.device("cpu")
if not isinstance(text_encoder_device, torch.device):
text_encoder_device = torch.device(text_encoder_device)
if all(p == "" for p in prompt):
input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
attention_mask = torch.ones_like(input_ids)
else:
tokenized = self.tokenizer(
prompt,
padding="longest",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = tokenized.input_ids.to(text_encoder_device)
attention_mask = tokenized.attention_mask.to(text_encoder_device)
if any(p == "" for p in prompt):
empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
input_ids[empty_rows] = bot_token_id
attention_mask[empty_rows] = 1
encoder_outputs = self.text_encoder(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_outputs.hidden_states
prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hidden_states = tuple(
layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
)
attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
return prompt_embeds, hidden_states, attention_mask
@staticmethod
def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
# Pad embeddings to `max_tokens` while preserving the mask of real tokens.
batch_size, seq_len, dim = prompt_embeds.shape
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
else:
attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if max_tokens < seq_len:
raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
if max_tokens > seq_len:
pad_length = max_tokens - seq_len
padding = torch.zeros(
(batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
mask_padding = torch.zeros(
(batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
return prompt_embeds, attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 3000,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
guidance_scale (`float`):
Guidance scale for classifier free guidance.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
prompt_attention_mask = None
negative_prompt_attention_mask = None
if prompt_embeds is None:
prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
if guidance_scale > 1:
if isinstance(negative_prompt, list) and negative_prompt[0] is None:
negative_prompt = ""
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
# Pad to longest
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if negative_prompt_embeds is not None:
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.to(
device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
)
max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
)
negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
else:
max_tokens = prompt_embeds.shape[1]
prompt_embeds, prompt_attention_mask = self.pad_embedding(
prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
)
negative_prompt_layers = None
dtype = self.text_encoder.dtype
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
return (
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@staticmethod
# Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels)
latents = latents.permute(0, 3, 1, 2)
return latents
@staticmethod
def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
latents = latents.permute(0, 2, 3, 1)
latents = latents.reshape(batch_size, height * width, num_channels_latents)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
do_patching=False,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if do_patching:
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
else:
latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@staticmethod
def _prepare_attention_mask(attention_mask):
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
# convert to 0 - keep, -inf ignore
attention_matrix = torch.where(
attention_matrix == 1, 0.0, -torch.inf
) # Apply -inf to ignored tokens for nulling softmax score
return attention_matrix
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
Examples:
Returns:
[`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_layers,
negative_prompt_layers,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
num_images_per_prompt=num_images_per_prompt,
lora_scale=lora_scale,
)
prompt_batch_size = prompt_embeds.shape[0]
if guidance_scale > 1:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_layers = [
torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
]
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
self.transformer.single_transformer_blocks
)
if len(prompt_layers) >= total_num_layers_transformer:
# remove first layers
prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
else:
# duplicate last layer
prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
if do_patching:
num_channels_latents = int(num_channels_latents / 4)
latents, latent_image_ids = self.prepare_latents(
prompt_batch_size,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
do_patching,
)
latent_attention_mask = torch.ones(
[latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
)
if guidance_scale > 1:
latent_attention_mask = latent_attention_mask.repeat(2, 1)
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
if self._joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask
# Adapt scheduler to dynamic shifting (resolution dependent)
if do_patching:
seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
else:
seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
# Init sigmas and timesteps according to shift size
# This changes the scheduler in-place according to the dynamic scheduling
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Support old different diffusers versions
if len(latent_image_ids.shape) == 3:
latent_image_ids = latent_image_ids[0]
if len(text_ids.shape) == 3:
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(
device=latent_model_input.device, dtype=latent_model_input.dtype
)
# This is predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_encoder_layers=prompt_layers,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if guidance_scale > 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
if do_patching:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
else:
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
latents = latents.unsqueeze(dim=2)
latents_device = latents[0].device
latents_dtype = latents[0].dtype
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents_device, latents_dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_device, latents_dtype
)
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
image.append(curr_image)
if len(image) == 1:
image = image[0]
else:
image = np.stack(image, axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return BriaFiboPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if max_sequence_length is not None and max_sequence_length > 3000:
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class BriaFiboPipelineOutput(BaseOutput):
"""
Output class for BriaFibo pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
+15
View File
@@ -588,6 +588,21 @@ class AutoModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BriaFiboTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BriaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -482,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class BriaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import BriaFiboTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}
@property
def input_shape(self):
return (16, 16)
@property
def output_shape(self):
return (256, 48)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 48,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"joint_attention_dim": 64,
"text_encoder_dim": 32,
"pooled_projection_dim": None,
"axes_dims_rope": [0, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@@ -51,7 +51,7 @@ class SDXLModularTests:
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
repo = "hf-internal-testing/tiny-sdxl-modular"
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
params = frozenset(
[
"prompt",
@@ -66,7 +66,9 @@ class SDXLModularTests:
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline = self.pipeline_blocks_class().init_pipeline(
self.pretrained_model_name_or_path, components_manager=components_manager
)
pipeline.load_components(torch_dtype=torch_dtype)
return pipeline
@@ -157,7 +159,7 @@ class SDXLModularIPAdapterTests:
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe = blocks.init_pipeline(self.pretrained_model_name_or_path)
pipe.load_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -0,0 +1,139 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
from diffusers import (
AutoencoderKLWan,
BriaFiboPipeline,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
enable_full_determinism()
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaFiboPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=160,
decoder_base_dim=256,
num_res_blocks=2,
out_channels=12,
patch_size=2,
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
z_dim=16,
)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "{'text': 'A painting of a squirrel eating a burger'}",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"output_type": "np",
}
return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
def test_encode_prompt_works_in_isolation(self):
pass
def test_bria_fibo_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)