5ffb73d4ae
* add vae * Initial commit for Flux 2 Transformer implementation * add pipeline part * small edits to the pipeline and conversion * update conversion script * fix * up up * finish pipeline * Remove Flux IP Adapter logic for now * Remove deprecated 3D id logic * Remove ControlNet logic for now * Add link to ViT-22B paper as reference for parallel transformer blocks such as the Flux 2 single stream block * update pipeline * Don't use biases for input projs and output AdaNorm * up * Remove bias for double stream block text QKV projections * Add script to convert Flux 2 transformer to diffusers * make style and make quality * fix a few things. * allow sft files to go. * fix image processor * fix batch * style a bit * Fix some bugs in Flux 2 transformer implementation * Fix dummy input preparation and fix some test bugs * fix dtype casting in timestep guidance module. * resolve conflicts. * remove ip adapter stuff. * Fix Flux 2 transformer consistency test * Fix bug in Flux2TransformerBlock (double stream block) * Get remaining Flux 2 transformer tests passing * make style; make quality; make fix-copies * remove stuff. * fix type annotation. 
* remove unneeded stuff from tests * tests * up * up * add sf support * Remove unused IP Adapter and ControlNet logic from transformer (#9) * copied from * Apply suggestions from code review Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: apolinário <joaopaulo.passos@gmail.com> * up * up * up * up * up * Refactor Flux2Attention into separate classes for double stream and single stream attention * Add _supports_qkv_fusion to AttentionModuleMixin to allow subclasses to disable QKV fusion * Have Flux2ParallelSelfAttention inherit from AttentionModuleMixin with _supports_qkv_fusion=False * Log debug message when calling fuse_projections on an AttentionModuleMixin subclass that does not support QKV fusion * Address review comments * Update src/diffusers/pipelines/flux2/pipeline_flux2.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * up * Remove maybe_allow_in_graph decorators for Flux 2 transformer blocks (#12) * up * support ostris loras. (#13) * up * update schedule * up * up (#17) * add training scripts (#16) * add training scripts Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com> * model cpu offload in validation. * add flux.2 readme * add img2img and tests * cpu offload in log validation * Apply suggestions from code review * fix * up * fixes * remove i2i training tests for now. --------- Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com> Co-authored-by: linoytsaban <linoy@huggingface.co> * up --------- Co-authored-by: yiyixuxu <yixu310@gmail.com> Co-authored-by: Daniel Gu <dgu8957@gmail.com> Co-authored-by: yiyi@huggingface.co <yiyi@ip-10-53-87-203.ec2.internal> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: apolinário <joaopaulo.passos@gmail.com> Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal> Co-authored-by: Linoy Tsaban <linoytsaban@gmail.com> Co-authored-by: linoytsaban <linoy@huggingface.co>
191 lines
6.6 KiB
Python
191 lines
6.6 KiB
Python
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration
|
|
|
|
from diffusers import (
|
|
AutoencoderKLFlux2,
|
|
FlowMatchEulerDiscreteScheduler,
|
|
Flux2Pipeline,
|
|
Flux2Transformer2DModel,
|
|
)
|
|
|
|
from ...testing_utils import (
|
|
torch_device,
|
|
)
|
|
from ..test_pipelines_common import (
|
|
PipelineTesterMixin,
|
|
check_qkv_fused_layers_exist,
|
|
)
|
|
|
|
|
|
class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests for `Flux2Pipeline` built from tiny, deterministically initialised components.

    Generic pipeline checks are inherited from `PipelineTesterMixin`; the class attributes
    below are presumably read by that mixin to decide which optional checks apply.
    """

    pipeline_class = Flux2Pipeline
    # Pipeline call arguments exercised by the shared tests, and the subset that is batched.
    params = frozenset({"prompt", "height", "width", "guidance_scale", "prompt_embeds"})
    batch_params = frozenset({"prompt"})

    # Feature toggles (presumably consumed by PipelineTesterMixin).
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        """Build a minimal, reproducible set of components for `Flux2Pipeline`.

        Each model is constructed immediately after `torch.manual_seed(0)`, so its random
        weights are the same on every call.

        Args:
            num_layers: Forwarded to `Flux2Transformer2DModel` as `num_layers`.
            num_single_layers: Forwarded to `Flux2Transformer2DModel` as `num_single_layers`.

        Returns:
            A dict with the keys "scheduler", "text_encoder", "tokenizer", "transformer"
            and "vae", suitable for `self.pipeline_class(**components)`.
        """

        def seeded(factory, *args, **kwargs):
            # Reset the global RNG right before a model is instantiated.
            torch.manual_seed(0)
            return factory(*args, **kwargs)

        transformer = seeded(
            Flux2Transformer2DModel,
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,  # Hardcoded in original code
            axes_dims_rope=[4, 4, 4, 4],
        )

        # Tiny "mistral" text tower and "pixtral" vision tower for the Mistral3 text encoder.
        text_tower = {
            "model_type": "mistral",
            "vocab_size": 32000,
            "hidden_size": 16,
            "intermediate_size": 37,
            "max_position_embeddings": 512,
            "num_attention_heads": 4,
            "num_hidden_layers": 1,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "bos_token_id": 2,
            "eos_token_id": 3,
            "pad_token_id": 4,
        }
        vision_tower = {
            "model_type": "pixtral",
            "hidden_size": 16,
            "num_hidden_layers": 1,
            "num_attention_heads": 4,
            "intermediate_size": 37,
            "image_size": 30,
            "patch_size": 6,
            "num_channels": 3,
        }
        encoder_config = Mistral3Config(
            text_config=text_tower,
            vision_config=vision_tower,
            bos_token_id=2,
            eos_token_id=3,
            pad_token_id=4,
            # NOTE(review): `model_dtype` looks like a typo for `model_type`; it is passed through
            # unchanged here — confirm how Mistral3Config treats it before renaming.
            model_dtype="mistral3",
            image_seq_length=4,
            vision_feature_layer=-1,
            image_token_index=1,
        )
        text_encoder = seeded(Mistral3ForConditionalGeneration, encoder_config)
        tokenizer = AutoProcessor.from_pretrained(
            "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor"
        )

        vae = seeded(
            AutoencoderKLFlux2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        return {
            "scheduler": FlowMatchEulerDiscreteScheduler(),
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return small, seeded keyword arguments for a pipeline call.

        On MPS devices the global RNG is seeded via `torch.manual_seed`; on any other device a
        dedicated CPU `torch.Generator` is seeded instead.

        Args:
            device: Target device (anything whose `str()` may start with "mps").
            seed: Seed for the returned generator.
        """
        on_mps = str(device).startswith("mps")
        rng = torch.manual_seed(seed) if on_mps else torch.Generator(device="cpu").manual_seed(seed)

        return dict(
            prompt="a dog is dancing",
            generator=rng,
            num_inference_steps=2,
            guidance_scale=5.0,
            height=8,
            width=8,
            max_sequence_length=8,
            output_type="np",
            # The dummy text tower has num_hidden_layers=1; presumably index 1 selects its
            # final hidden state — TODO confirm against the pipeline.
            text_encoder_out_layers=(1,),
        )

    def test_fused_qkv_projections(self):
        """Fusing, then unfusing, the transformer's QKV projections keeps outputs (within tolerance)."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        pipe.set_progress_bar_config(disable=None)

        def bottom_right_patch():
            # Fresh inputs (hence a freshly seeded generator) for every run; keep the
            # bottom-right 3x3 patch of the last channel of the first image.
            frames = pipe(**self.get_dummy_inputs(device)).images
            return frames[0, -3:, -3:, -1]

        reference = bottom_right_patch()

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            "Something wrong with the fused attention layers. Expected all the attention projections to be fused.",
        )
        fused = bottom_right_patch()

        pipe.transformer.unfuse_qkv_projections()
        unfused = bottom_right_patch()

        self.assertTrue(
            np.allclose(reference, fused, atol=1e-3, rtol=1e-3),
            "Fusion of QKV projections shouldn't affect the outputs.",
        )
        self.assertTrue(
            np.allclose(fused, unfused, atol=1e-3, rtol=1e-3),
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled.",
        )
        self.assertTrue(
            np.allclose(reference, unfused, atol=1e-2, rtol=1e-2),
            "Original outputs should match when fused QKV projections are disabled.",
        )

    def test_flux_image_output_shape(self):
        """Output height/width equal the requested ones rounded down to a multiple of 2 * vae_scale_factor."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        size_multiple = pipe.vae_scale_factor * 2

        for height, width in ((32, 32), (72, 57)):
            expected_shape = (height - height % size_multiple, width - width % size_multiple)

            # The same inputs dict (and generator) is reused; only the requested size changes.
            inputs.update(height=height, width=width)
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                expected_shape,
                f"Output shape {image.shape} does not match expected shape {expected_shape}",
            )