Compare commits

..

8 Commits

Author SHA1 Message Date
sayakpaul 6397a67892 up 2025-11-27 10:30:18 +05:30
sayakpaul 82685f2e11 fix system prompts 🤷 2025-11-26 14:41:14 +05:30
Sayak Paul b07bee3148 Merge branch 'main' into flux2-upsample 2025-11-26 12:28:57 +05:30
sayakpaul ceb8a3a2f9 up 2025-11-26 12:27:23 +05:30
sayakpaul 0b1f884459 fix 2025-11-26 12:14:04 +05:30
sayakpaul b4a840698b up 2025-11-26 11:29:29 +05:30
sayakpaul e6a0ab6244 doc 2025-11-26 11:04:34 +05:30
sayakpaul 7350d07bd2 feat: implement caption upsampling for flux.2. 2025-11-26 10:31:36 +05:30
15 changed files with 332 additions and 689 deletions
+6
View File
@@ -26,6 +26,12 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
>
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Caption upsampling
Flux.2 can potentially generate better outputs with better prompts. We can "upsample"
an input prompt by setting the `caption_upsample_temperature` argument when calling the pipeline.
The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends this value to be 0.15.
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
@@ -69,10 +69,7 @@ class TimestepEmbedder(nn.Module):
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
weight_dtype = self.mlp[0].weight.dtype
if weight_dtype.is_floating_point:
t_freq = t_freq.to(weight_dtype)
t_emb = self.mlp(t_freq)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
@@ -129,10 +126,6 @@ class ZSingleStreamAttnProcessor:
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
# Compute joint attention
hidden_states = dispatch_attention_fn(
query,
@@ -313,10 +306,6 @@ class RopeEmbedder:
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
@@ -328,7 +317,6 @@ class RopeEmbedder:
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
@register_to_config
def __init__(
@@ -565,6 +553,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
t = t * self.t_scale
t = self.t_embedder(t)
adaln_input = t
(
x,
cap_feats,
@@ -582,9 +572,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
# Match t_embedder output dtype to x for layerwise casting compatibility
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Tuple
from typing import List
import PIL.Image
@@ -96,9 +96,9 @@ class Flux2ImageProcessor(VaeImageProcessor):
)
return image
@staticmethod
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
image_width, image_height = image.size
scale = math.sqrt(target_area / (image_width * image_height))
@@ -106,6 +106,14 @@ class Flux2ImageProcessor(VaeImageProcessor):
height = int(image_height * scale)
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
@staticmethod
def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image:
    """
    Downscale `image` only when its pixel count is larger than `target_area`.

    An image whose `width * height` is at most `target_area` is handed back untouched (the very same object);
    anything larger is delegated to `Flux2ImageProcessor._resize_to_target_area`.
    """
    width, height = image.size
    exceeds_budget = width * height > target_area
    if exceeds_budget:
        return Flux2ImageProcessor._resize_to_target_area(image, target_area)
    return image
def _resize_and_crop(
self,
@@ -136,3 +144,35 @@ class Flux2ImageProcessor(VaeImageProcessor):
bottom = top + height
return image.crop((left, top, right, bottom))
# Taken from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19
@staticmethod
def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image:
    """
    Concatenate a list of PIL images horizontally with center alignment and white background.

    Args:
        images (`List[PIL.Image.Image]`):
            Non-empty list of images. They are placed left to right in list order.

    Returns:
        `PIL.Image.Image`:
            For a single input image, a copy of that image (its mode is left as is). Otherwise a new RGB image
            whose width is the sum of the input widths and whose height is the tallest input height.

    Raises:
        `ValueError`: If `images` is empty.
    """
    # Fail early with a readable message instead of the opaque error `max()` raises on an empty sequence.
    if not images:
        raise ValueError("`images` must contain at least one image to concatenate.")

    # If only one image, return a copy of it
    if len(images) == 1:
        return images[0].copy()

    # Convert all images to RGB if not already
    images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

    # Calculate dimensions for horizontal concatenation
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    # Create new image with white background
    background_color = (255, 255, 255)
    new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)

    # Paste images with center alignment: shorter images are vertically centered on the canvas
    x_offset = 0
    for img in images:
        y_offset = (max_height - img.height) // 2
        new_img.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return new_img
+174 -21
View File
@@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
if is_torch_xla_available():
@@ -56,25 +57,107 @@ EXAMPLE_DOC_STRING = """
```
"""
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
def format_text_input(prompts: List[str], system_message: str = None):
# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
def format_input(
    prompts: List[str],
    system_message: str = SYSTEM_MESSAGE,
    images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
):
    """
    Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
    to the input.

    Args:
        prompts: List of text prompts
        system_message: System message to use (default: SYSTEM_MESSAGE)
        images (optional): One entry per prompt. Each entry is either `None` (no image message is added for that
            prompt) or a collection of images that is iterated to build the image message.

    Returns:
        List of conversations, where each conversation is a list of message dicts. Without images every
        conversation is `[system, user-text]`; with images it is `[system, user-images (if any), user-text]`.

    Raises:
        ValueError: If `images` is non-empty and its length differs from the number of prompts.
    """
    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
    # when truncation is enabled. The processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]

    if images is None or len(images) == 0:
        return [
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_message}],
                },
                {"role": "user", "content": [{"type": "text", "text": prompt}]},
            ]
            for prompt in cleaned_txt
        ]

    # Explicit check (rather than `assert`) so the validation is not stripped under `python -O`,
    # where a length mismatch would otherwise be silently truncated by `zip` below.
    if len(images) != len(prompts):
        raise ValueError("Number of images must match number of prompts")

    messages = []
    for text, prompt_images in zip(cleaned_txt, images):
        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
        ]
        # optionally add the images per batch element.
        if prompt_images is not None:
            conversation.append(
                {
                    "role": "user",
                    "content": [{"type": "image", "image": image_obj} for image_obj in prompt_images],
                }
            )
        # add the text.
        conversation.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}],
            }
        )
        messages.append(conversation)

    return messages
# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19
def _validate_and_process_images(
    images: Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]],
    image_processor: Flux2ImageProcessor,
    upsampling_max_image_size: int,
) -> List[List[PIL.Image.Image]]:
    """
    Normalize the images passed for prompt upsampling into one single-image list per batch element.

    Args:
        images: Either a flat list of PIL images (each one becomes its own batch element) or a list of lists of
            PIL images (one inner list per batch element).
        image_processor: Processor providing `concatenate_images` and `_resize_if_exceeds_area`.
        upsampling_max_image_size: Pixel-count budget forwarded to `_resize_if_exceeds_area`.

    Returns:
        A list of lists of PIL images, or an empty list when `images` is empty/falsy. Inner lists that had more
        than one image are replaced by a single concatenated image before the size cap is applied.
    """
    # Simple validation: ensure it's a list of PIL images or list of lists of PIL images
    if not images:
        return []

    # Check if it's a list of lists or a list of images
    if isinstance(images[0], PIL.Image.Image):
        # It's a list of images, convert to list of lists
        images = [[im] for im in images]

    # potentially concatenate multiple images to reduce the size
    images = [
        [image_processor.concatenate_images(group)] if len(group) > 1 else group for group in images
    ]

    # cap the pixels
    images = [
        [image_processor._resize_if_exceeds_area(img, upsampling_max_image_size) for img in group]
        for group in images
    ]
    return images
# Taken from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
@@ -214,9 +297,10 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
self.tokenizer_max_length = 512
self.default_sample_size = 128
# fmt: off
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
# fmt: on
self.system_message = SYSTEM_MESSAGE
self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
@staticmethod
def _get_mistral_3_small_prompt_embeds(
@@ -226,9 +310,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
# fmt: off
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
# fmt: on
system_message: str = SYSTEM_MESSAGE,
hidden_states_layers: List[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
@@ -237,7 +319,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt
# Format input messages
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
messages_batch = format_input(prompts=prompt, system_message=system_message)
# Process all messages at once
inputs = tokenizer.apply_chat_template(
@@ -426,6 +508,68 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
return torch.stack(x_list, dim=0)
def upsample_prompt(
    self,
    prompt: Union[str, List[str]],
    images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None,
    temperature: float = 0.15,
    device: torch.device = None,
) -> List[str]:
    """
    Rewrite ("upsample") the given prompt(s) by sampling from the text encoder's `generate` method.

    Args:
        prompt (`str` or `List[str]`):
            Prompt or prompts to rewrite. A single string is treated as a batch of one.
        images (`PIL.Image.Image`, `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`, *optional*):
            Reference images. When provided (and the first entry is not `None`) the image-to-image system
            message is used, otherwise the text-to-image one. A single image is wrapped into a list.
        temperature (`float`, defaults to 0.15):
            Sampling temperature forwarded to `generate`.
        device (`torch.device`, *optional*):
            Device the tokenized inputs are moved to. Defaults to `self.text_encoder.device`.

    Returns:
        `List[str]`: The decoded newly generated text, one entry per prompt.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    device = self.text_encoder.device if device is None else device

    # A bare PIL image supports neither `len()` nor indexing, both of which are used below,
    # so treat it as a one-element list.
    if isinstance(images, PIL.Image.Image):
        images = [images]

    # Set system message based on whether images are provided
    if images is None or len(images) == 0 or images[0] is None:
        system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
    else:
        system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I

    # Validate and process the input images
    if images:
        images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)

    # Format input messages
    messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)

    # Process all messages at once.
    # With image inputs, a `max_length` that is too short can raise an error here.
    inputs = self.tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=2048,
    )

    # Move to device
    inputs["input_ids"] = inputs["input_ids"].to(device)
    inputs["attention_mask"] = inputs["attention_mask"].to(device)

    # `pixel_values` is only moved/cast when the processor returned it
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype)

    # Generate text using the model's generate method
    generated_ids = self.text_encoder.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=temperature,
        use_cache=True,
    )

    # Decode only the newly generated tokens: everything past the input length
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = generated_ids[:, input_length:]

    upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return upsampled_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -620,6 +764,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
caption_upsample_temperature: float = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -635,11 +780,11 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
guidance_scale (`float`, *optional*, defaults to 1.0):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
a model to generate images more aligned with `prompt` at the expense of lower image quality.
Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -684,6 +829,9 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
text_encoder_out_layers (`Tuple[int]`):
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
caption_upsample_temperature (`float`):
When specified, we will try to perform caption upsampling for potentially improved outputs. We
recommend setting it to 0.15 if caption upsampling is to be performed.
Examples:
@@ -718,6 +866,10 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
device = self._execution_device
# 3. prepare text embeddings
if caption_upsample_temperature:
prompt = self.upsample_prompt(
prompt, images=image, temperature=caption_upsample_temperature, device=device
)
prompt_embeds, text_ids = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
@@ -861,6 +1013,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
if output_type == "latent":
image = latents
else:
torch.save({"pred": latents}, "pred_d.pt")
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
@@ -0,0 +1,29 @@
"""
These system prompts come from:
https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54
"""
# System message used as the default `system_message` when formatting prompts for text encoding.
# NOTE(review): this literal spans two source lines, so the string contains a line break after
# "object" — presumably intentional to mirror the upstream file; confirm.
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""
# System message for caption upsampling when no input images are given (text-to-image).
SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
Output only the revised prompt and nothing else."""
# System message for caption upsampling when input images are given (image-to-image / editing).
# NOTE(review): in the two rules that pair quoted examples, the quotes run together with no
# separator between them; presumably a connector (e.g. an arrow) was lost — verify against the
# upstream file linked above before relying on this text.
SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X""keep X")
- Make abstractions concrete ("futuristic""glowing cyan neon, metallic panels")
- Keep content PG-13
Output only the final instruction in plain text and nothing else."""
@@ -165,16 +165,21 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
@@ -188,6 +193,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
@@ -199,9 +206,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
assert num_images_per_prompt == 1
device = device or self._execution_device
if prompt_embeds is not None:
@@ -407,6 +417,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
f"Please adjust the width to a multiple of {vae_scale}."
)
assert self.dtype == torch.bfloat16
dtype = self.dtype
device = self._execution_device
self._guidance_scale = guidance_scale
@@ -422,6 +434,10 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
else:
batch_size = len(prompt_embeds)
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -439,8 +455,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
dtype=dtype,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
@@ -456,14 +475,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
generator,
latents,
)
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
# 5. Prepare timesteps
@@ -512,12 +523,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
if apply_cfg:
latents_typed = latents.to(self.transformer.dtype)
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents.to(self.transformer.dtype)
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
@@ -532,11 +543,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
pos_out = model_out_list[:batch_size]
neg_out = model_out_list[batch_size:]
noise_pred = []
for j in range(actual_batch_size):
for j in range(batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
@@ -577,11 +588,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
latents = latents.to(dtype)
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
@@ -429,22 +429,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -467,10 +452,6 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -401,17 +401,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -819,22 +808,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError("only support log-rho multistep deis now")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -857,10 +831,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -957,21 +927,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -127,17 +127,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -146,14 +147,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
[DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
for a small number of steps. It is recommended to use `midpoint` solvers.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -177,16 +179,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to 1.0):
The shift value for the timestep schedule for flow matching.
final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`"learned"` or `"learned_range"`, *optional*):
Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
output contains the predicted Gaussian variance.
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
contains the predicted Gaussian variance.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
@@ -195,10 +197,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
use_dynamic_shifting (`bool`, defaults to `False`):
Whether to use dynamic shifting for the timestep schedule.
time_shift_type (`"exponential"`, defaults to `"exponential"`):
The type of time shift to apply when using dynamic shifting.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -210,15 +208,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
@@ -227,14 +225,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
use_lu_lambdas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[Literal["learned", "learned_range"]] = None,
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: Literal["exponential"] = "exponential",
time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -333,22 +331,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
) -> None:
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`, *optional*):
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
mu (`float`, *optional*):
The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
`time_shift_type="exponential"`.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
@@ -508,7 +503,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
def _sigma_to_t(self, sigma, log_sigmas):
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -544,18 +539,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
def _sigma_to_alpha_sigma_t(self, sigma):
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -604,21 +588,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
Args:
in_lambdas (`torch.Tensor`):
The input lambda values to be converted.
num_inference_steps (`int`):
The number of inference steps to generate the noise schedule for.
Returns:
`torch.Tensor`:
The converted lambda values following the Lu noise schedule.
"""
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Lu et al. (2022)."""
lambda_min: float = in_lambdas[-1].item()
lambda_max: float = in_lambdas[0].item()
@@ -1098,22 +1069,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1132,13 +1088,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -1153,7 +1105,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
@@ -1163,22 +1115,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int` or `torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`, *optional*):
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
return_dict (`bool`, defaults to `True`):
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
@@ -1258,21 +1210,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -413,17 +413,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -491,17 +491,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -1090,22 +1079,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"Order must be 1, 2, 3, got {order}")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1128,10 +1102,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -1234,21 +1204,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -578,22 +578,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -616,10 +601,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -423,17 +423,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -1114,22 +1103,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1152,10 +1126,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -513,17 +513,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -995,22 +984,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1033,10 +1007,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -1149,21 +1119,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
-306
View File
@@ -1,306 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import unittest
import numpy as np
import torch
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
ZImagePipeline,
ZImageTransformer2DModel,
)
from ...testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
# Z-Image's RoPE relies on complex64 operations, so deterministic algorithms have to stay disabled.
# `enable_full_determinism()` cannot be used here because it switches them on.
os.environ.update(CUDA_LAUNCH_BLOCKING="1", CUBLAS_WORKSPACE_CONFIG=":16:8")
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Only touch the CUDA matmul backend when this torch build exposes it.
if hasattr(torch.backends, "cuda"):
    torch.backends.cuda.matmul.allow_tf32 = False
# Note: Some tests (test_float16_inference, test_save_load_float16) may fail in full suite
# due to RopeEmbedder cache state pollution between tests. They pass when run individually.
# This is a known test isolation issue, not a functional bug.
class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests for `ZImagePipeline` using tiny, randomly initialised components.

    Every test starts and ends with a garbage collection, a CUDA cache flush (when CUDA is available) and a
    reseed of the torch RNGs, so that state left behind by one test is less likely to influence the next.
    """

    pipeline_class = ZImagePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    @staticmethod
    def _free_memory():
        """Run the garbage collector and, when CUDA is available, empty its cache and synchronize."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @classmethod
    def _reset_torch_state(cls):
        """Free memory, then reseed the CPU (and, when available, CUDA) random number generators with 0."""
        cls._free_memory()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def setUp(self):
        # NOTE(review): unlike `tearDown`, this does not call `super().setUp()` - confirm that is intended.
        self._reset_torch_state()

    def tearDown(self):
        super().tearDown()
        self._reset_torch_state()

    def get_dummy_components(self):
        """Build the tiny transformer, VAE, scheduler, text encoder and tokenizer used by every test.

        The RNG is reseeded with 0 before each randomly initialised component is created.

        Returns:
            `dict`: The components keyed by the argument names passed to `pipeline_class(**components)`.
        """
        torch.manual_seed(0)
        transformer = ZImageTransformer2DModel(
            all_patch_size=(2,),
            all_f_patch_size=(1,),
            in_channels=16,
            dim=32,
            n_layers=2,
            n_refiner_layers=1,
            n_heads=2,
            n_kv_heads=2,
            norm_eps=1e-5,
            qk_norm=True,
            cap_feat_dim=16,
            rope_theta=256.0,
            t_scale=1000.0,
            axes_dims=[8, 4, 4],
            axes_lens=[256, 32, 32],
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            block_out_channels=[32, 64],
            layers_per_block=1,
            latent_channels=16,
            norm_num_groups=32,
            sample_size=32,
            scaling_factor=0.3611,
            shift_factor=0.1159,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen3Config(
            hidden_size=16,
            intermediate_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        text_encoder = Qwen3Model(config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        return {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return the keyword arguments for a minimal 32x32, two-step pipeline call.

        Args:
            device (`str` or `torch.device`):
                Device on which the random generator is created. For devices whose string form starts with
                `"mps"` the global generator returned by `torch.manual_seed` is used instead.
            seed (`int`, defaults to 0):
                Seed for the generator.

        Returns:
            `dict`: Keyword arguments for the pipeline call.
        """
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        return {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "cfg_normalization": False,
            "cfg_truncation": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

    def test_inference(self):
        # Runs the pipeline on CPU and compares the first and last 8 values of the flattened image
        # against a stored reference slice.
        device = "cpu"

        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        image = pipe(**self.get_dummy_inputs(device)).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # fmt: off
        expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
        # fmt: on

        flat_image = generated_image.flatten()
        generated_slice = torch.cat([flat_image[:8], flat_image[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=5e-2))

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

    def test_num_images_per_prompt(self):
        # Checks that the number of returned images equals `batch_size * num_images_per_prompt`.
        import inspect

        sig = inspect.signature(self.pipeline_class.__call__)
        if "num_images_per_prompt" not in sig.parameters:
            return

        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        batch_sizes = [1, 2]
        num_images_per_prompts = [1, 2]

        for batch_size in batch_sizes:
            for num_images_per_prompt in num_images_per_prompts:
                inputs = self.get_dummy_inputs(torch_device)

                # Only values are replaced, so iterating the dict while assigning is safe.
                for key in inputs:
                    if key in self.batch_params:
                        inputs[key] = batch_size * [inputs[key]]

                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
                # `assertEqual` is not stripped under `python -O` and reports both values on failure.
                self.assertEqual(images.shape[0], batch_size * num_images_per_prompt)

        # Drop the pipeline before freeing memory so its tensors can actually be released.
        del pipe
        self._free_memory()

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        # Compares outputs with attention slicing (slice sizes 1 and 2) against the unsliced output.
        # `test_mean_pixel_difference` is accepted for signature compatibility but is not used here.
        if not self.test_attention_slicing:
            return

        pipe = self.pipeline_class(**self.get_dummy_components())
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        output_without_slicing = pipe(**self.get_dummy_inputs(generator_device))[0]

        pipe.enable_attention_slicing(slice_size=1)
        output_with_slicing1 = pipe(**self.get_dummy_inputs(generator_device))[0]

        pipe.enable_attention_slicing(slice_size=2)
        output_with_slicing2 = pipe(**self.get_dummy_inputs(generator_device))[0]

        if test_max_difference:
            reference = to_np(output_without_slicing)
            max_diff1 = np.abs(to_np(output_with_slicing1) - reference).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - reference).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        # Compares 128x128 outputs generated with and without VAE tiling on CPU.
        generator_device = "cpu"

        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling (standard AutoencoderKL doesn't accept parameters)
        pipe.vae.enable_tiling()
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        # Use the absolute difference: a signed maximum would let arbitrarily large negative
        # deviations pass the tolerance check unnoticed.
        max_abs_diff = np.abs(to_np(output_without_tiling) - to_np(output_with_tiling)).max()
        self.assertLess(
            max_abs_diff,
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4):
        # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
        super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)

    def test_group_offloading_inference(self):
        # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
        self.skipTest("Using test_pipeline_level_group_offloading_inference instead")

    def test_save_load_float16(self, expected_max_diff=1e-2):
        # Reset memory and RNG state once more right before delegating to the shared test.
        self._reset_torch_state()
        super().test_save_load_float16(expected_max_diff=expected_max_diff)