Add Photon model and pipeline support (#12456)

* Add Photon model and pipeline support

This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: Core transformer architecture
- PhotonPipeline: Text-to-image generation pipeline
- Attention processor updates for Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests

* just store the T5Gemma encoder

* enhance_vae_properties if vae is provided only

* remove autocast for text encoder forwad

* BF16 example

* conditioned CFG

* remove enhance vae and use vae.config directly when possible

* move PhotonAttnProcessor2_0 in transformer_photon

* remove einops dependency and now inherits from AttentionMixin

* unify the structure of the forward block

* update doc

* update doc

* fix T5Gemma loading from hub

* fix timestep shift

* remove lora support from doc

* Rename EmbedND for PhotoEmbedND

* remove modulation dataclass

* put _attn_forward and _ffn_forward logic in PhotonBlock's forward

* renam LastLayer for FinalLayer

* remove lora related code

* rename vae_spatial_compression_ratio for vae_scale_factor

* support prompt_embeds in call

* move xattention conditionning out computation out of the denoising loop

* add negative prompts

* Use _import_structure for lazy loading

* make quality + style

* add pipeline test + corresponding fixes

* utility function that determines the default resolution given the VAE

* Refactor PhotonAttention to match Flux pattern

* built-in RMSNorm

* Revert accidental .gitignore change

* parameter names match the standard diffusers conventions

* renaming and remove unecessary attributes setting

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* quantization example

* added doc to toctree

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* use dispatch_attention_fn for multiple attention backend support

* naming changes

* make fix copy

* Update docs/source/en/api/pipelines/photon.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Add PhotonTransformer2DModel to TYPE_CHECKING imports

* make fix-copies

* Use Tuple instead of tuple

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* restrict the version of transformers

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/photon/test_pipeline_photon.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/photon/test_pipeline_photon.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* change | for Optional

* fix nits.

* use typing Dict

---------

Co-authored-by: davidb <davidb@worker-10.soperator-worker-svc.soperator.svc.cluster.local>
Co-authored-by: David Briand <david@photoroom.com>
Co-authored-by: davidb <davidb@worker-8.soperator-worker-svc.soperator.svc.cluster.local>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
This commit is contained in:
David Bertoin
2025-10-21 17:25:55 +02:00
committed by GitHub
parent b3e56e71fb
commit cefc2cf82d
16 changed files with 2501 additions and 0 deletions
@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PhotonTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16, 16)
@property
def output_shape(self):
return (16, 16, 16)
def prepare_dummy_input(self, height=16, width=16):
batch_size = 1
num_latent_channels = 16
sequence_length = 16
embedding_dim = 1792
hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 16,
"patch_size": 2,
"context_in_dim": 1792,
"hidden_size": 1792,
"mlp_ratio": 3.5,
"num_heads": 28,
"depth": 4, # Smaller depth for testing
"axes_dim": [32, 32],
"theta": 10_000,
}
inputs_dict = self.prepare_dummy_input()
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"PhotonTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
if __name__ == "__main__":
unittest.main()
View File
@@ -0,0 +1,265 @@
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@pytest.mark.xfail(
condition=is_transformers_version(">", "4.57.1"),
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PhotonPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
@classmethod
def setUpClass(cls):
# Ensure PhotonPipeline has an _execution_device property expected by __call__
if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property):
try:
setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
except Exception:
pass
def get_dummy_components(self):
torch.manual_seed(0)
transformer = PhotonTransformer2DModel(
patch_size=1,
in_channels=4,
context_in_dim=8,
hidden_size=8,
mlp_ratio=2.0,
num_heads=2,
depth=1,
axes_dim=[2, 2],
)
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0,
scaling_factor=1.0,
).eval()
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
tokenizer.model_max_length = 64
torch.manual_seed(0)
encoder_params = {
"vocab_size": tokenizer.vocab_size,
"hidden_size": 8,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"attention_bias": False,
"attention_dropout": 0.0,
"dropout_rate": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
"rms_norm_eps": 1e-06,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 4,
"rope_theta": 10000.0,
"sliding_window": 4096,
}
encoder_config = T5GemmaModuleConfig(**encoder_params)
text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
text_encoder = T5GemmaEncoder(text_encoder_config)
return {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
return {
"prompt": "",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"output_type": "pt",
"use_resolution_binning": False,
}
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = PhotonPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
try:
pipe.register_to_config(_execution_device="cpu")
except Exception:
pass
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.zeros(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
components = self.get_dummy_components()
pipe = PhotonPipeline(**components)
pipe = pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
try:
pipe.register_to_config(_execution_device="cpu")
except Exception:
pass
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs("cpu")
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
_ = pipe(**inputs)[0]
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
_ = pipe(**inputs)[0]
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
def to_np_local(tensor):
if isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
return tensor
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
def test_inference_with_autoencoder_dc(self):
"""Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
device = "cpu"
components = self.get_dummy_components()
torch.manual_seed(0)
vae_dc = AutoencoderDC(
in_channels=3,
latent_channels=4,
attention_head_dim=2,
encoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
decoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
encoder_block_out_channels=(8, 8),
decoder_block_out_channels=(8, 8),
encoder_qkv_multiscales=((), (5,)),
decoder_qkv_multiscales=((), (5,)),
encoder_layers_per_block=(1, 1),
decoder_layers_per_block=(1, 1),
upsample_block_type="interpolate",
downsample_block_type="stride_conv",
decoder_norm_types="rms_norm",
decoder_act_fns="silu",
).eval()
components["vae"] = vae_dc
pipe = PhotonPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
expected_scale_factor = vae_dc.spatial_compression_ratio
self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.zeros(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)