Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6bf668c4d2 | |||
| e6d4612309 | |||
| a88a7b4f03 | |||
| c8656ed73c |
@@ -26,12 +26,6 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
|
||||
>
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## Caption upsampling
|
||||
|
||||
Flux.2 can potentially generate better better outputs with better prompts. We can "upsample"
|
||||
an input prompt by setting the `caption_upsample_temperature` argument in the pipeline call arguments.
|
||||
The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends this value to be 0.15.
|
||||
|
||||
## Flux2Pipeline
|
||||
|
||||
[[autodoc]] Flux2Pipeline
|
||||
|
||||
@@ -69,7 +69,10 @@ class TimestepEmbedder(nn.Module):
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
||||
weight_dtype = self.mlp[0].weight.dtype
|
||||
if weight_dtype.is_floating_point:
|
||||
t_freq = t_freq.to(weight_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
@@ -126,6 +129,10 @@ class ZSingleStreamAttnProcessor:
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
@@ -306,6 +313,10 @@ class RopeEmbedder:
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
else:
|
||||
# Ensure freqs_cis are on the same device as ids
|
||||
if self.freqs_cis[0].device != device:
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
@@ -317,6 +328,7 @@ class RopeEmbedder:
|
||||
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ZImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -553,8 +565,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
adaln_input = t
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
@@ -572,6 +582,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import PIL.Image
|
||||
|
||||
@@ -96,9 +96,9 @@ class Flux2ImageProcessor(VaeImageProcessor):
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
|
||||
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
|
||||
image_width, image_height = image.size
|
||||
|
||||
scale = math.sqrt(target_area / (image_width * image_height))
|
||||
@@ -106,14 +106,6 @@ class Flux2ImageProcessor(VaeImageProcessor):
|
||||
height = int(image_height * scale)
|
||||
|
||||
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
|
||||
|
||||
@staticmethod
|
||||
def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image:
|
||||
image_width, image_height = image.size
|
||||
pixel_count = image_width * image_height
|
||||
if pixel_count <= target_area:
|
||||
return image
|
||||
return Flux2ImageProcessor._resize_to_target_area(image, target_area)
|
||||
|
||||
def _resize_and_crop(
|
||||
self,
|
||||
@@ -144,35 +136,3 @@ class Flux2ImageProcessor(VaeImageProcessor):
|
||||
bottom = top + height
|
||||
|
||||
return image.crop((left, top, right, bottom))
|
||||
|
||||
# Taken from
|
||||
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19
|
||||
@staticmethod
|
||||
def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image:
|
||||
"""
|
||||
Concatenate a list of PIL images horizontally with center alignment and white background.
|
||||
"""
|
||||
|
||||
# If only one image, return a copy of it
|
||||
if len(images) == 1:
|
||||
return images[0].copy()
|
||||
|
||||
# Convert all images to RGB if not already
|
||||
images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
|
||||
|
||||
# Calculate dimensions for horizontal concatenation
|
||||
total_width = sum(img.width for img in images)
|
||||
max_height = max(img.height for img in images)
|
||||
|
||||
# Create new image with white background
|
||||
background_color = (255, 255, 255)
|
||||
new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)
|
||||
|
||||
# Paste images with center alignment
|
||||
x_offset = 0
|
||||
for img in images:
|
||||
y_offset = (max_height - img.height) // 2
|
||||
new_img.paste(img, (x_offset, y_offset))
|
||||
x_offset += img.width
|
||||
|
||||
return new_img
|
||||
|
||||
@@ -28,7 +28,6 @@ from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -57,107 +56,25 @@ EXAMPLE_DOC_STRING = """
|
||||
```
|
||||
"""
|
||||
|
||||
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
|
||||
def format_input(
|
||||
prompts: List[str],
|
||||
system_message: str = SYSTEM_MESSAGE,
|
||||
images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
|
||||
):
|
||||
"""
|
||||
Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
|
||||
to the input.
|
||||
|
||||
Args:
|
||||
prompts: List of text prompts
|
||||
system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
|
||||
images (optional): List of images to add to the input.
|
||||
|
||||
Returns:
|
||||
List of conversations, where each conversation is a list of message dicts
|
||||
"""
|
||||
def format_text_input(prompts: List[str], system_message: str = None):
|
||||
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
|
||||
# when truncation is enabled. The processor counts [IMG] tokens and fails
|
||||
# if the count changes after truncation.
|
||||
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
|
||||
|
||||
if images is None or len(images) == 0:
|
||||
return [
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
for prompt in cleaned_txt
|
||||
]
|
||||
else:
|
||||
assert len(images) == len(prompts), "Number of images must match number of prompts"
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
]
|
||||
for _ in cleaned_txt
|
||||
]
|
||||
|
||||
for i, (el, images) in enumerate(zip(messages, images)):
|
||||
# optionally add the images per batch element.
|
||||
if images is not None:
|
||||
el.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image", "image": image_obj} for image_obj in images],
|
||||
}
|
||||
)
|
||||
# add the text.
|
||||
el.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": cleaned_txt[i]}],
|
||||
}
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19
|
||||
def _validate_and_process_images(
|
||||
images: List[List[PIL.Image.Image]] | List[PIL.Image.Image],
|
||||
image_processor: Flux2ImageProcessor,
|
||||
upsampling_max_image_size: int,
|
||||
) -> List[List[PIL.Image.Image]]:
|
||||
# Simple validation: ensure it's a list of PIL images or list of lists of PIL images
|
||||
if not images:
|
||||
return []
|
||||
|
||||
# Check if it's a list of lists or a list of images
|
||||
if isinstance(images[0], PIL.Image.Image):
|
||||
# It's a list of images, convert to list of lists
|
||||
images = [[im] for im in images]
|
||||
|
||||
# potentially concatenate multiple images to reduce the size
|
||||
images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images]
|
||||
|
||||
# cap the pixels
|
||||
images = [
|
||||
return [
|
||||
[
|
||||
image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size)
|
||||
for img_i in img_i
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
for img_i in images
|
||||
for prompt in cleaned_txt
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
# Taken from
|
||||
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
@@ -297,10 +214,9 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
self.tokenizer_max_length = 512
|
||||
self.default_sample_size = 128
|
||||
|
||||
self.system_message = SYSTEM_MESSAGE
|
||||
self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
|
||||
self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
|
||||
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
|
||||
# fmt: off
|
||||
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
|
||||
# fmt: on
|
||||
|
||||
@staticmethod
|
||||
def _get_mistral_3_small_prompt_embeds(
|
||||
@@ -310,7 +226,9 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
system_message: str = SYSTEM_MESSAGE,
|
||||
# fmt: off
|
||||
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
|
||||
# fmt: on
|
||||
hidden_states_layers: List[int] = (10, 20, 30),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
@@ -319,7 +237,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
# Format input messages
|
||||
messages_batch = format_input(prompts=prompt, system_message=system_message)
|
||||
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
|
||||
|
||||
# Process all messages at once
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
@@ -508,68 +426,6 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
def upsample_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None,
|
||||
temperature: float = 0.15,
|
||||
device: torch.device = None,
|
||||
) -> List[str]:
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
device = self.text_encoder.device if device is None else device
|
||||
|
||||
# Set system message based on whether images are provided
|
||||
if images is None or len(images) == 0 or images[0] is None:
|
||||
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
|
||||
else:
|
||||
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
|
||||
|
||||
# Validate and process the input images
|
||||
if images:
|
||||
images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)
|
||||
|
||||
# Format input messages
|
||||
messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)
|
||||
|
||||
# Process all messages at once
|
||||
# with image processing a too short max length can throw an error in here.
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
messages_batch,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=2048,
|
||||
)
|
||||
|
||||
# Move to device
|
||||
inputs["input_ids"] = inputs["input_ids"].to(device)
|
||||
inputs["attention_mask"] = inputs["attention_mask"].to(device)
|
||||
|
||||
if "pixel_values" in inputs:
|
||||
inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype)
|
||||
|
||||
# Generate text using the model's generate method
|
||||
generated_ids = self.text_encoder.generate(
|
||||
**inputs,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Decode only the newly generated tokens (skip input tokens)
|
||||
# Extract only the generated portion
|
||||
input_length = inputs["input_ids"].shape[1]
|
||||
generated_tokens = generated_ids[:, input_length:]
|
||||
|
||||
upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
|
||||
generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return upsampled_prompt
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -764,7 +620,6 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
||||
caption_upsample_temperature: float = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -780,11 +635,11 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
guidance_scale (`float`, *optional*, defaults to 1.0):
|
||||
Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
||||
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
||||
|
||||
Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
|
||||
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
@@ -829,9 +684,6 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
text_encoder_out_layers (`Tuple[int]`):
|
||||
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
|
||||
caption_upsample_temperature (`float`):
|
||||
When specified, we will try to perform caption upsampling for potentially improved outputs. We
|
||||
recommend setting it to 0.15 if caption upsampling is to be performed.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -866,10 +718,6 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare text embeddings
|
||||
if caption_upsample_temperature:
|
||||
prompt = self.upsample_prompt(
|
||||
prompt, images=image, temperature=caption_upsample_temperature, device=device
|
||||
)
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
@@ -1013,7 +861,6 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
torch.save({"pred": latents}, "pred_d.pt")
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
These system prompts come from:
|
||||
https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54
|
||||
"""
|
||||
|
||||
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
|
||||
attribution and actions without speculation."""
|
||||
|
||||
SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
|
||||
|
||||
Guidelines:
|
||||
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
|
||||
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
|
||||
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
|
||||
|
||||
Output only the revised prompt and nothing else."""
|
||||
|
||||
SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
|
||||
|
||||
Rules:
|
||||
- Single instruction only, no commentary
|
||||
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
|
||||
- Specify what changes AND what stays the same (face, lighting, composition)
|
||||
- Reference actual image elements
|
||||
- Turn negatives into positives ("don't change X" → "keep X")
|
||||
- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
|
||||
- Keep content PG-13
|
||||
|
||||
Output only the final instruction in plain text and nothing else."""
|
||||
@@ -165,21 +165,16 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
@@ -193,8 +188,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
@@ -206,12 +199,9 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
assert num_images_per_prompt == 1
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
@@ -417,8 +407,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
assert self.dtype == torch.bfloat16
|
||||
dtype = self.dtype
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -434,10 +422,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
@@ -455,11 +439,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -475,6 +456,14 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
@@ -523,12 +512,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
@@ -543,11 +532,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:batch_size]
|
||||
neg_out = model_out_list[batch_size:]
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(batch_size):
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
@@ -588,11 +577,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
latents = latents.to(dtype)
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
|
||||
@@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
solver_order (`int`, defaults to 2):
|
||||
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
|
||||
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
|
||||
directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
|
||||
Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
@@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
||||
`algorithm_type="dpmsolver++"`.
|
||||
algorithm_type (`str`, defaults to `dpmsolver++`):
|
||||
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
||||
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
||||
paper, and the `dpmsolver++` type implements the algorithms in the
|
||||
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
||||
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
||||
solver_type (`str`, defaults to `midpoint`):
|
||||
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
||||
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||
algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
|
||||
Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
|
||||
[DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
|
||||
algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
|
||||
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
||||
solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
|
||||
Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
|
||||
for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||
lower_order_final (`bool`, defaults to `True`):
|
||||
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
||||
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
||||
@@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||
The shift value for the timestep schedule for flow matching.
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
variance_type (`str`, *optional*):
|
||||
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
||||
contains the predicted Gaussian variance.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
variance_type (`"learned"` or `"learned_range"`, *optional*):
|
||||
Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
|
||||
output contains the predicted Gaussian variance.
|
||||
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
@@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
use_dynamic_shifting (`bool`, defaults to `False`):
|
||||
Whether to use dynamic shifting for the timestep schedule.
|
||||
time_shift_type (`"exponential"`, defaults to `"exponential"`):
|
||||
The type of time shift to apply when using dynamic shifting.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
@@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_lu_lambdas: Optional[bool] = False,
|
||||
use_flow_sigmas: Optional[bool] = False,
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
time_shift_type: Literal["exponential"] = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
mu: Optional[float] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
mu (`float`, *optional*):
|
||||
The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
|
||||
`time_shift_type="exponential"`.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
||||
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
||||
@@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Lu et al. (2022)."""
|
||||
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
|
||||
Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
|
||||
|
||||
Args:
|
||||
in_lambdas (`torch.Tensor`):
|
||||
The input lambda values to be converted.
|
||||
num_inference_steps (`int`):
|
||||
The number of inference steps to generate the noise schedule for.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The converted lambda values following the Lu noise schedule.
|
||||
"""
|
||||
|
||||
lambda_min: float = in_lambdas[-1].item()
|
||||
lambda_max: float = in_lambdas[0].item()
|
||||
@@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return step_index
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
@@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`int`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
variance_noise (`torch.Tensor`):
|
||||
variance_noise (`torch.Tensor`, *optional*):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`LEdits++`].
|
||||
return_dict (`bool`):
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
@@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
|
||||
@@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
|
||||
@@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
|
||||
@@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
ZImagePipeline,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
|
||||
# Cannot use enable_full_determinism() which sets it to True
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
if hasattr(torch.backends, "cuda"):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
# Note: Some tests (test_float16_inference, test_save_load_float16) may fail in full suite
|
||||
# due to RopeEmbedder cache state pollution between tests. They pass when run individually.
|
||||
# This is a known test isolation issue, not a functional bug.
|
||||
|
||||
|
||||
class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ZImagePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = ZImageTransformer2DModel(
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=32,
|
||||
n_layers=2,
|
||||
n_refiner_layers=1,
|
||||
n_heads=2,
|
||||
n_kv_heads=2,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=16,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[8, 4, 4],
|
||||
axes_lens=[256, 32, 32],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
block_out_channels=[32, 64],
|
||||
layers_per_block=1,
|
||||
latent_channels=16,
|
||||
norm_num_groups=32,
|
||||
sample_size=32,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen3Config(
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
text_encoder = Qwen3Model(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"cfg_normalization": False,
|
||||
"cfg_truncation": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=5e-2))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
if "num_images_per_prompt" not in sig.parameters:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling (standard AutoencoderKL doesn't accept parameters)
|
||||
pipe.vae.enable_tiling()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4):
|
||||
# Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
|
||||
super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
|
||||
|
||||
def test_group_offloading_inference(self):
|
||||
# Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
|
||||
self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
|
||||
|
||||
def test_save_load_float16(self, expected_max_diff=1e-2):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
super().test_save_load_float16(expected_max_diff=expected_max_diff)
|
||||
Reference in New Issue
Block a user