Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6e539e9505 | |||
| 6e15e47422 | |||
| d4b6f6c1cc | |||
| 169d45ca35 | |||
| 5c669f8798 | |||
| 0f599d9901 |
@@ -839,9 +839,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
|
||||
@@ -725,9 +725,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -747,9 +747,9 @@ class TokenEmbeddingsHandler:
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
self.embeddings_settings[
|
||||
f"original_embeddings_{idx}"
|
||||
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
|
||||
@@ -890,9 +890,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -912,9 +912,9 @@ class TokenEmbeddingsHandler:
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
self.embeddings_settings[
|
||||
f"original_embeddings_{idx}"
|
||||
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
|
||||
@@ -907,12 +907,12 @@ def create_controller(
|
||||
|
||||
# reweight
|
||||
if edit_type == "reweight":
|
||||
assert equalizer_words is not None and equalizer_strengths is not None, (
|
||||
"To use reweight edit, please specify equalizer_words and equalizer_strengths."
|
||||
)
|
||||
assert len(equalizer_words) == len(equalizer_strengths), (
|
||||
"equalizer_words and equalizer_strengths must be of same length."
|
||||
)
|
||||
assert (
|
||||
equalizer_words is not None and equalizer_strengths is not None
|
||||
), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
|
||||
assert len(equalizer_words) == len(
|
||||
equalizer_strengths
|
||||
), "equalizer_words and equalizer_strengths must be of same length."
|
||||
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
|
||||
return AttentionReweight(
|
||||
prompts,
|
||||
|
||||
@@ -731,18 +731,18 @@ def main(args):
|
||||
if not class_images_dir.exists():
|
||||
class_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
if args.real_prior:
|
||||
assert (class_images_dir / "images").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (class_images_dir / "caption.txt").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (class_images_dir / "images.txt").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (
|
||||
class_images_dir / "images"
|
||||
).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
assert (
|
||||
len(list((class_images_dir / "images").iterdir())) == args.num_class_images
|
||||
), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
assert (
|
||||
class_images_dir / "caption.txt"
|
||||
).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
assert (
|
||||
class_images_dir / "images.txt"
|
||||
).exists(), f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
|
||||
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
|
||||
args.concepts_list[i] = concept
|
||||
|
||||
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
assert pipeline.transformer.config.in_channels == initial_channels * 2, (
|
||||
f"{pipeline.transformer.config.in_channels=}"
|
||||
)
|
||||
assert (
|
||||
pipeline.transformer.config.in_channels == initial_channels * 2
|
||||
), f"{pipeline.transformer.config.in_channels=}"
|
||||
|
||||
pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -1081,9 +1081,9 @@ class AutoConfig:
|
||||
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
|
||||
)
|
||||
|
||||
pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
|
||||
textual_inversion_path.model_path
|
||||
)
|
||||
pretrained_model_name_or_paths[
|
||||
pretrained_model_name_or_paths.index(search_word)
|
||||
] = textual_inversion_path.model_path
|
||||
|
||||
self.load_textual_inversion(
|
||||
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
|
||||
|
||||
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"]
|
||||
assert torch.count_nonzero(tokens - 49407) == 2, (
|
||||
f"String '{string}' maps to more than a single token. Please use another string"
|
||||
)
|
||||
assert (
|
||||
torch.count_nonzero(tokens - 49407) == 2
|
||||
), f"String '{string}' maps to more than a single token. Please use another string"
|
||||
return tokens[0, 1]
|
||||
|
||||
|
||||
|
||||
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], (
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
)
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
+6
-6
@@ -763,9 +763,9 @@ def main(args):
|
||||
# Parse instance and class inputs, and double check that lengths match
|
||||
instance_data_dir = args.instance_data_dir.split(",")
|
||||
instance_prompt = args.instance_prompt.split(",")
|
||||
assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
|
||||
"Instance data dir and prompt inputs are not of the same length."
|
||||
)
|
||||
assert all(
|
||||
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
|
||||
), "Instance data dir and prompt inputs are not of the same length."
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_data_dir = args.class_data_dir.split(",")
|
||||
@@ -788,9 +788,9 @@ def main(args):
|
||||
negative_validation_prompts.append(None)
|
||||
args.validation_negative_prompt = negative_validation_prompts
|
||||
|
||||
assert num_of_validation_prompts == len(negative_validation_prompts), (
|
||||
"The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
)
|
||||
assert num_of_validation_prompts == len(
|
||||
negative_validation_prompts
|
||||
), "The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
|
||||
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
|
||||
|
||||
|
||||
@@ -830,9 +830,9 @@ def main():
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = get_mask(tokenizer, accelerator)
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -886,9 +886,9 @@ def main():
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -910,9 +910,9 @@ def main():
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -965,12 +965,12 @@ def main():
|
||||
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
|
||||
orig_embeds_params_2[index_no_updates_2]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
|
||||
index_no_updates_2
|
||||
] = orig_embeds_params_2[index_no_updates_2]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -653,15 +653,15 @@ def main():
|
||||
try:
|
||||
# Gets the resolution of the timm transformation after centercrop
|
||||
timm_centercrop_transform = timm_transform.transforms[1]
|
||||
assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
|
||||
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
)
|
||||
assert isinstance(
|
||||
timm_centercrop_transform, transforms.CenterCrop
|
||||
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
timm_model_resolution = timm_centercrop_transform.size[0]
|
||||
# Gets final normalization
|
||||
timm_model_normalization = timm_transform.transforms[-1]
|
||||
assert isinstance(timm_model_normalization, transforms.Normalize), (
|
||||
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
)
|
||||
assert isinstance(
|
||||
timm_model_normalization, transforms.Normalize
|
||||
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
except AssertionError as e:
|
||||
raise NotImplementedError(e)
|
||||
# Enable flash attention if asked
|
||||
|
||||
@@ -261,9 +261,9 @@ def main(args):
|
||||
|
||||
model_name = args.model_path.split("/")[-1].split(".")[0]
|
||||
if not os.path.isfile(args.model_path):
|
||||
assert model_name == args.model_path, (
|
||||
f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
)
|
||||
assert (
|
||||
model_name == args.model_path
|
||||
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
args.model_path = download(model_name)
|
||||
|
||||
sample_rate = MODELS_MAP[model_name]["sample_rate"]
|
||||
@@ -290,9 +290,9 @@ def main(args):
|
||||
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
|
||||
|
||||
for key, value in renamed_state_dict.items():
|
||||
assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
|
||||
f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
)
|
||||
assert (
|
||||
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
|
||||
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
if key == "time_proj.weight":
|
||||
value = value.squeeze()
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ def main(args):
|
||||
model_config = HunyuanDiT2DControlNetModel.load_config(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
|
||||
)
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
print(model_config)
|
||||
|
||||
for key in state_dict:
|
||||
|
||||
@@ -18,9 +18,9 @@ def main(args):
|
||||
|
||||
device = "cuda"
|
||||
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
|
||||
# input_size -> sample_size, text_dim -> cross_attention_dim
|
||||
for key in state_dict:
|
||||
|
||||
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
|
||||
|
||||
block_out_channels = original_config["channels"]
|
||||
|
||||
assert len(set(original_config["depths"])) == 1, (
|
||||
"UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
)
|
||||
assert (
|
||||
len(set(original_config["depths"])) == 1
|
||||
), "UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
layers_per_block = original_config["depths"][0]
|
||||
|
||||
class_labels_dim = original_config["mapping_cond_dim"]
|
||||
|
||||
@@ -305,9 +305,9 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
|
||||
for i in range(down_block_layers[block]):
|
||||
# Convert resnets
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
|
||||
)
|
||||
@@ -317,9 +317,9 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
|
||||
)
|
||||
|
||||
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
|
||||
|
||||
# TODO resnet time_mixer.mix_factor
|
||||
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
new_checkpoint[
|
||||
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
|
||||
)
|
||||
|
||||
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
new_checkpoint[
|
||||
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
|
||||
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
|
||||
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert original_config["target"] in PORTED_VQVAES, (
|
||||
f"{original_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert (
|
||||
original_config["target"] in PORTED_VQVAES
|
||||
), f"{original_config['target']} has not yet been ported to diffusers."
|
||||
|
||||
original_config = original_config["params"]
|
||||
|
||||
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
|
||||
f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
|
||||
f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
|
||||
f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert (
|
||||
original_diffusion_config["target"] in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config["target"] in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
|
||||
original_diffusion_config = original_diffusion_config["params"]
|
||||
original_transformer_config = original_transformer_config["params"]
|
||||
|
||||
@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_unet:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
# Handle text encoder LoRAs.
|
||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
te_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
te2_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
# Store alpha if present.
|
||||
if lora_name_alpha in state_dict:
|
||||
@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
|
||||
## time_text_embed.text_embedder <- vector_in
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
|
||||
@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
# guidance
|
||||
has_guidance = any("guidance" in k for k in original_state_dict)
|
||||
if has_guidance:
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
|
||||
|
||||
@@ -211,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
|
||||
def _init_vectorized_inputs(self, norm_type):
|
||||
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert self.config.num_vector_embeds is not None, (
|
||||
"Transformer2DModel over discrete input must provide num_embed"
|
||||
)
|
||||
assert (
|
||||
self.config.num_vector_embeds is not None
|
||||
), "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = self.config.sample_size
|
||||
self.width = self.config.sample_size
|
||||
|
||||
@@ -341,9 +341,9 @@ class AnimateDiffFreeNoiseMixin:
|
||||
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
|
||||
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
|
||||
|
||||
negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
|
||||
self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
)
|
||||
negative_prompt_interpolation_embeds[
|
||||
start_frame : end_frame + 1
|
||||
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
|
||||
prompt_embeds = prompt_interpolation_embeds
|
||||
negative_prompt_embeds = negative_prompt_interpolation_embeds
|
||||
|
||||
@@ -24,6 +24,7 @@ from transformers import (
|
||||
CLIPTokenizer,
|
||||
LlamaTokenizerFast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaProcessor,
|
||||
)
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
@@ -100,6 +101,50 @@ DEFAULT_PROMPT_TEMPLATE = {
|
||||
}
|
||||
|
||||
|
||||
def _expand_input_ids_with_image_tokens(
|
||||
text_input_ids,
|
||||
prompt_attention_mask,
|
||||
max_sequence_length,
|
||||
image_token_index,
|
||||
image_emb_len,
|
||||
image_emb_start,
|
||||
image_emb_end,
|
||||
pad_token_id,
|
||||
):
|
||||
special_image_token_mask = text_input_ids == image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
|
||||
|
||||
max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
expanded_input_ids = torch.full(
|
||||
(text_input_ids.shape[0], max_expanded_length),
|
||||
pad_token_id,
|
||||
dtype=text_input_ids.dtype,
|
||||
device=text_input_ids.device,
|
||||
)
|
||||
expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
|
||||
expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
|
||||
|
||||
expanded_attention_mask = torch.zeros(
|
||||
(text_input_ids.shape[0], max_expanded_length),
|
||||
dtype=prompt_attention_mask.dtype,
|
||||
device=prompt_attention_mask.device,
|
||||
)
|
||||
attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
|
||||
expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
|
||||
expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
|
||||
position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
|
||||
|
||||
return {
|
||||
"input_ids": expanded_input_ids,
|
||||
"attention_mask": expanded_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -231,6 +276,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
self.llava_processor = LlavaProcessor(
|
||||
self.image_processor,
|
||||
self.tokenizer,
|
||||
patch_size=self.text_encoder.config.vision_config.patch_size,
|
||||
vision_feature_select_strategy=self.text_encoder.config.vision_feature_select_strategy,
|
||||
num_additional_image_tokens=1,
|
||||
)
|
||||
|
||||
def _get_llama_prompt_embeds(
|
||||
self,
|
||||
@@ -251,6 +303,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt = [prompt_template["template"].format(p) for p in prompt]
|
||||
|
||||
crop_start = prompt_template.get("crop_start", None)
|
||||
|
||||
image_emb_len = prompt_template.get("image_emb_len", 576)
|
||||
image_emb_start = prompt_template.get("image_emb_start", 5)
|
||||
image_emb_end = prompt_template.get("image_emb_end", 581)
|
||||
double_return_token_id = prompt_template.get("double_return_token_id", 271)
|
||||
|
||||
if crop_start is None:
|
||||
prompt_template_input = self.tokenizer(
|
||||
prompt_template["template"],
|
||||
@@ -280,19 +338,25 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
|
||||
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
|
||||
|
||||
image_token_index = self.text_encoder.config.image_token_index
|
||||
pad_token_id = self.text_encoder.config.pad_token_id
|
||||
expanded_inputs = _expand_input_ids_with_image_tokens(
|
||||
text_input_ids,
|
||||
prompt_attention_mask,
|
||||
max_sequence_length,
|
||||
image_token_index,
|
||||
image_emb_len,
|
||||
image_emb_start,
|
||||
image_emb_end,
|
||||
pad_token_id,
|
||||
)
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
pixel_values=image_embeds,
|
||||
**expanded_inputs,
|
||||
pixel_value=image_embeds,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
image_emb_len = prompt_template.get("image_emb_len", 576)
|
||||
image_emb_start = prompt_template.get("image_emb_start", 5)
|
||||
image_emb_end = prompt_template.get("image_emb_end", 581)
|
||||
double_return_token_id = prompt_template.get("double_return_token_id", 271)
|
||||
|
||||
if crop_start is not None and crop_start > 0:
|
||||
text_crop_start = crop_start - 1 + image_emb_len
|
||||
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
|
||||
|
||||
@@ -95,13 +95,13 @@ class OmniGenMultiModalProcessor:
|
||||
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
||||
|
||||
unique_image_ids = sorted(set(image_ids))
|
||||
assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
|
||||
f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
||||
)
|
||||
assert unique_image_ids == list(
|
||||
range(1, len(unique_image_ids) + 1)
|
||||
), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
||||
# total images must be the same as the number of image tags
|
||||
assert len(unique_image_ids) == len(input_images), (
|
||||
f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
||||
)
|
||||
assert (
|
||||
len(unique_image_ids) == len(input_images)
|
||||
), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
||||
|
||||
input_images = [input_images[x - 1] for x in image_ids]
|
||||
|
||||
|
||||
@@ -983,9 +983,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
fields = torch.cat(fields, dim=1)
|
||||
fields = fields.float()
|
||||
|
||||
assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
|
||||
f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
|
||||
)
|
||||
assert (
|
||||
len(fields.shape) == 3 and fields.shape[-1] == 1
|
||||
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
|
||||
|
||||
fields = fields.reshape(1, *([grid_size] * 3))
|
||||
|
||||
@@ -1039,9 +1039,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
textures = textures.float()
|
||||
|
||||
# 3.3 augument the mesh with texture data
|
||||
assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
|
||||
f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
|
||||
)
|
||||
assert len(textures.shape) == 3 and textures.shape[-1] == len(
|
||||
texture_channels
|
||||
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
|
||||
|
||||
for m, texture in zip(raw_meshes, textures):
|
||||
texture = texture[: len(m.verts)]
|
||||
|
||||
@@ -215,15 +215,19 @@ class DiffusersQuantizer(ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _process_model_before_weight_loading(self, model, **kwargs): ...
|
||||
def _process_model_before_weight_loading(self, model, **kwargs):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _process_model_after_weight_loading(self, model, **kwargs): ...
|
||||
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_serializable(self): ...
|
||||
def is_serializable(self):
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_trainable(self): ...
|
||||
def is_trainable(self):
|
||||
...
|
||||
|
||||
@@ -299,9 +299,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
assert download_requests.count("HEAD") == 3, (
|
||||
"3 HEAD requests one for config, one for model, and one for shard index file."
|
||||
)
|
||||
assert (
|
||||
download_requests.count("HEAD") == 3
|
||||
), "3 HEAD requests one for config, one for model, and one for shard index file."
|
||||
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
@@ -313,9 +313,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
|
||||
"We should call only `model_info` to check for commit hash and knowing if shard index is present."
|
||||
)
|
||||
assert (
|
||||
"HEAD" == cache_requests[0] and len(cache_requests) == 2
|
||||
), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
|
||||
|
||||
def test_weight_overwrite(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
||||
|
||||
@@ -92,9 +92,9 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
|
||||
"xformers is not enabled"
|
||||
)
|
||||
assert (
|
||||
model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
@@ -167,9 +167,9 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
|
||||
"xformers is not enabled"
|
||||
)
|
||||
assert (
|
||||
model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
|
||||
@@ -654,22 +654,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
|
||||
keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
|
||||
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
|
||||
assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
|
||||
"a 'keep all' mask should give the same result as no mask"
|
||||
)
|
||||
assert full_cond_keepallmask_out.allclose(
|
||||
full_cond_out, rtol=1e-05, atol=1e-05
|
||||
), "a 'keep all' mask should give the same result as no mask"
|
||||
|
||||
trunc_cond = cond[:, :-1, :]
|
||||
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
|
||||
assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
|
||||
"discarding the last token from our cond should change the result"
|
||||
)
|
||||
assert not trunc_cond_out.allclose(
|
||||
full_cond_out, rtol=1e-05, atol=1e-05
|
||||
), "discarding the last token from our cond should change the result"
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
|
||||
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
|
||||
assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
|
||||
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
|
||||
)
|
||||
assert masked_cond_out.allclose(
|
||||
trunc_cond_out, rtol=1e-05, atol=1e-05
|
||||
), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
@@ -697,9 +697,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
assert trunc_mask_out.allclose(
|
||||
keeplast_out
|
||||
), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
@@ -1114,12 +1114,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
assert not torch.allclose(
|
||||
non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
|
||||
), "LoRA injected UNet should produce different results."
|
||||
assert torch.allclose(
|
||||
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
|
||||
), "Loading from a saved checkpoint should produce identical results."
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
|
||||
@@ -65,9 +65,9 @@ class ImageProcessorTest(unittest.TestCase):
|
||||
)
|
||||
out_np = self.to_np(out)
|
||||
in_np = (input_np * 255).round() if output_type == "pil" else input_np
|
||||
assert np.abs(in_np - out_np).max() < 1e-6, (
|
||||
f"decoded output does not match input for output_type {output_type}"
|
||||
)
|
||||
assert (
|
||||
np.abs(in_np - out_np).max() < 1e-6
|
||||
), f"decoded output does not match input for output_type {output_type}"
|
||||
|
||||
def test_vae_image_processor_np(self):
|
||||
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
|
||||
@@ -78,9 +78,9 @@ class ImageProcessorTest(unittest.TestCase):
|
||||
|
||||
out_np = self.to_np(out)
|
||||
in_np = (input_np * 255).round() if output_type == "pil" else input_np
|
||||
assert np.abs(in_np - out_np).max() < 1e-6, (
|
||||
f"decoded output does not match input for output_type {output_type}"
|
||||
)
|
||||
assert (
|
||||
np.abs(in_np - out_np).max() < 1e-6
|
||||
), f"decoded output does not match input for output_type {output_type}"
|
||||
|
||||
def test_vae_image_processor_pil(self):
|
||||
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
|
||||
@@ -93,9 +93,9 @@ class ImageProcessorTest(unittest.TestCase):
|
||||
for i, o in zip(input_pil, out):
|
||||
in_np = np.array(i)
|
||||
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
|
||||
assert np.abs(in_np - out_np).max() < 1e-6, (
|
||||
f"decoded output does not match input for output_type {output_type}"
|
||||
)
|
||||
assert (
|
||||
np.abs(in_np - out_np).max() < 1e-6
|
||||
), f"decoded output does not match input for output_type {output_type}"
|
||||
|
||||
def test_preprocess_input_3d(self):
|
||||
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
|
||||
@@ -293,9 +293,9 @@ class ImageProcessorTest(unittest.TestCase):
|
||||
scale = 2
|
||||
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
|
||||
exp_pt_shape = (b, c, h // scale, w // scale)
|
||||
assert out_pt.shape == exp_pt_shape, (
|
||||
f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
|
||||
)
|
||||
assert (
|
||||
out_pt.shape == exp_pt_shape
|
||||
), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
|
||||
|
||||
def test_vae_image_processor_resize_np(self):
|
||||
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
|
||||
@@ -305,6 +305,6 @@ class ImageProcessorTest(unittest.TestCase):
|
||||
input_np = self.to_np(input_pt)
|
||||
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
|
||||
exp_np_shape = (b, h // scale, w // scale, c)
|
||||
assert out_np.shape == exp_np_shape, (
|
||||
f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
|
||||
)
|
||||
assert (
|
||||
out_np.shape == exp_np_shape
|
||||
), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
|
||||
|
||||
@@ -126,7 +126,8 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
|
||||
|
||||
@unittest.skip("aMUSEd does not support lists of generators")
|
||||
def test_inference_batch_single_identical(self): ...
|
||||
def test_inference_batch_single_identical(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -126,7 +126,8 @@ class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
|
||||
|
||||
@unittest.skip("aMUSEd does not support lists of generators")
|
||||
def test_inference_batch_single_identical(self): ...
|
||||
def test_inference_batch_single_identical(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -130,7 +130,8 @@ class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
|
||||
|
||||
@unittest.skip("aMUSEd does not support lists of generators")
|
||||
def test_inference_batch_single_identical(self): ...
|
||||
def test_inference_batch_single_identical(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -106,9 +106,9 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -122,15 +122,15 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
@unittest.skip("xformers attention processor does not exist for AuraFlow")
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
|
||||
@@ -195,9 +195,9 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
[0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
|
||||
|
||||
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
|
||||
@@ -299,9 +299,9 @@ class CogVideoXPipelineFastTests(
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -315,15 +315,15 @@ class CogVideoXPipelineFastTests(
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -299,9 +299,9 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -315,12 +315,12 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
@@ -317,9 +317,9 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -333,15 +333,15 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -298,9 +298,9 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -314,12 +314,12 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
@@ -219,9 +219,9 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes
|
||||
assert image.shape == (1, 16, 16, 4)
|
||||
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
|
||||
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
|
||||
@@ -178,9 +178,9 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
|
||||
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
|
||||
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
|
||||
@@ -170,9 +170,9 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -186,15 +186,15 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
@@ -162,9 +162,9 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix
|
||||
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(
|
||||
|
||||
@@ -194,9 +194,9 @@ class StableDiffusion3ControlInpaintNetPipelineFastTests(unittest.TestCase, Pipe
|
||||
[0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
|
||||
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
|
||||
@@ -202,9 +202,9 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
else:
|
||||
expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
|
||||
|
||||
def test_controlnet_sd3(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -170,9 +170,9 @@ class FluxPipelineFastTests(
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -186,15 +186,15 @@ class FluxPipelineFastTests(
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
@@ -140,9 +140,9 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -156,15 +156,15 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
@@ -134,9 +134,9 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -150,15 +150,15 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
@@ -174,9 +174,9 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -192,15 +192,15 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_disabled = pipe(**inputs)[0]
|
||||
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
@unittest.skip(
|
||||
"Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
|
||||
|
||||
@@ -240,12 +240,12 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
expected_slice = np.array([1.0000, 1.0000, 0.2766, 1.0000, 0.5447, 0.1737, 1.0000, 0.4316, 0.9024])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
|
||||
@@ -98,12 +98,12 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
expected_slice = np.array([0.2893, 0.1464, 0.4603, 0.3529, 0.4612, 0.7701, 0.4027, 0.3051, 0.5155])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
@@ -206,12 +206,12 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
|
||||
|
||||
expected_slice = np.array([0.4852, 0.4136, 0.4539, 0.4781, 0.4680, 0.5217, 0.4973, 0.4089, 0.4977])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
@@ -318,12 +318,12 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
|
||||
|
||||
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
|
||||
@@ -261,12 +261,12 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.5816, 0.5872, 0.4634, 0.5982, 0.4767, 0.4710, 0.4669, 0.4717, 0.4966])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
|
||||
@@ -256,12 +256,12 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
expected_slice = np.array([0.8222, 0.8896, 0.4373, 0.8088, 0.4905, 0.2609, 0.6816, 0.4291, 0.5129])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
@@ -210,13 +210,13 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
expected_slice = np.array([0.3420, 0.9505, 0.3919, 1.0000, 0.5188, 0.3109, 0.6139, 0.5624, 0.6811])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=1e-1)
|
||||
|
||||
@@ -103,12 +103,12 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
|
||||
expected_slice = np.array([0.3076, 0.2729, 0.5668, 0.0522, 0.3384, 0.7028, 0.4908, 0.3659, 0.6243])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
@@ -227,12 +227,12 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
|
||||
|
||||
expected_slice = np.array([0.4445, 0.4287, 0.4596, 0.3919, 0.3730, 0.5039, 0.4834, 0.4269, 0.5521])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
@@ -350,12 +350,12 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
|
||||
|
||||
expected_slice = np.array([0.5039, 0.4926, 0.4898, 0.4978, 0.4838, 0.4942, 0.4738, 0.4702, 0.4816])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
|
||||
@@ -210,13 +210,13 @@ class KandinskyV22ControlnetPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
[0.6959826, 0.868279, 0.7558092, 0.68769467, 0.85805804, 0.65977496, 0.44885302, 0.5959111, 0.4251595]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=1e-1)
|
||||
|
||||
@@ -218,12 +218,12 @@ class KandinskyV22ControlnetImg2ImgPipelineFastTests(PipelineTesterMixin, unitte
|
||||
expected_slice = np.array(
|
||||
[0.54985034, 0.55509365, 0.52561504, 0.5570494, 0.5593818, 0.5263979, 0.50285643, 0.5069846, 0.51196736]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=1.75e-3)
|
||||
|
||||
@@ -228,12 +228,12 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.5712, 0.5443, 0.4725, 0.6195, 0.5184, 0.4651, 0.4473, 0.4590, 0.5016])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=2e-1)
|
||||
|
||||
@@ -234,12 +234,12 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
[0.50775903, 0.49527195, 0.48824543, 0.50192237, 0.48644906, 0.49373814, 0.4780598, 0.47234827, 0.48327848]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
@@ -157,9 +157,9 @@ class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=1e-1)
|
||||
|
||||
@@ -181,9 +181,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
[0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=1e-1)
|
||||
|
||||
@@ -450,9 +450,9 @@ class AnimateDiffPAGPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -169,9 +169,9 @@ class StableDiffusionControlNetPAGPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
@@ -165,9 +165,9 @@ class StableDiffusionControlNetPAGInpaintPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
@@ -187,9 +187,9 @@ class StableDiffusionXLControlNetPAGPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
@@ -189,9 +189,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
@@ -177,15 +177,15 @@ class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_disabled = pipe(**inputs)[0]
|
||||
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pag_disable_enable(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -198,9 +198,9 @@ class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -140,9 +140,9 @@ class KolorsPAGPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
@@ -120,9 +120,9 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
|
||||
out = pipe(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
@@ -268,9 +268,9 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -154,9 +154,9 @@ class StableDiffusionPAGPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
@@ -328,9 +328,9 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
|
||||
@@ -345,6 +345,6 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -170,9 +170,9 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -186,15 +186,15 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pag_disable_enable(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -207,9 +207,9 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -149,9 +149,9 @@ class StableDiffusion3PAGImg2ImgPipelineFastTests(unittest.TestCase, PipelineTes
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
@@ -254,9 +254,9 @@ class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
0.17822266,
|
||||
]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
@@ -272,6 +272,6 @@ class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -161,9 +161,9 @@ class StableDiffusionPAGImg2ImgPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
@@ -267,9 +267,9 @@ class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
|
||||
@@ -285,6 +285,6 @@ class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -302,9 +302,9 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
|
||||
@@ -319,6 +319,6 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -167,9 +167,9 @@ class StableDiffusionXLPAGPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
@@ -331,9 +331,9 @@ class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
|
||||
@@ -348,6 +348,6 @@ class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -215,9 +215,9 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
@@ -316,9 +316,9 @@ class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
|
||||
@@ -333,6 +333,6 @@ class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -220,9 +220,9 @@ class StableDiffusionXLPAGInpaintPipelineFastTests(
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
|
||||
f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
)
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
@@ -322,9 +322,9 @@ class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
|
||||
@@ -339,6 +339,6 @@ class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
|
||||
f"output is different from expected, {image_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
), f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
@@ -260,9 +260,9 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -276,15 +276,15 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -198,12 +198,12 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
|
||||
@@ -293,15 +293,15 @@ class StableDiffusionPipelineFastTests(
|
||||
inputs["sigmas"] = sigma_schedule
|
||||
output_sigmas = sd_pipe(**inputs).images
|
||||
|
||||
assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
|
||||
"ays timesteps and ays sigmas should have the same outputs"
|
||||
)
|
||||
assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
|
||||
"use ays timesteps should have different outputs"
|
||||
)
|
||||
assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
|
||||
"use ays sigmas should have different outputs"
|
||||
)
|
||||
assert (
|
||||
np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
|
||||
), "ays timesteps and ays sigmas should have the same outputs"
|
||||
assert (
|
||||
np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
|
||||
), "use ays timesteps should have different outputs"
|
||||
assert (
|
||||
np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
|
||||
), "use ays sigmas should have different outputs"
|
||||
|
||||
def test_stable_diffusion_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
@@ -656,9 +656,9 @@ class StableDiffusionPipelineFastTests(
|
||||
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
|
||||
|
||||
assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
|
||||
"Enabling of FreeU should lead to different results."
|
||||
)
|
||||
assert not np.allclose(
|
||||
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
|
||||
), "Enabling of FreeU should lead to different results."
|
||||
|
||||
def test_freeu_disabled(self):
|
||||
components = self.get_dummy_components()
|
||||
@@ -681,9 +681,9 @@ class StableDiffusionPipelineFastTests(
|
||||
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
|
||||
).images
|
||||
|
||||
assert np.allclose(output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]), (
|
||||
"Disabling of FreeU should lead to results similar to the default pipeline results."
|
||||
)
|
||||
assert np.allclose(
|
||||
output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
|
||||
), "Disabling of FreeU should lead to results similar to the default pipeline results."
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -706,15 +706,15 @@ class StableDiffusionPipelineFastTests(
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -171,9 +171,9 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
@@ -187,15 +187,15 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_skip_guidance_layers(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -242,15 +242,15 @@ class StableDiffusionXLPipelineFastTests(
|
||||
inputs["sigmas"] = sigma_schedule
|
||||
output_sigmas = sd_pipe(**inputs).images
|
||||
|
||||
assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
|
||||
"ays timesteps and ays sigmas should have the same outputs"
|
||||
)
|
||||
assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
|
||||
"use ays timesteps should have different outputs"
|
||||
)
|
||||
assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
|
||||
"use ays sigmas should have different outputs"
|
||||
)
|
||||
assert (
|
||||
np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
|
||||
), "ays timesteps and ays sigmas should have the same outputs"
|
||||
assert (
|
||||
np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
|
||||
), "use ays timesteps should have different outputs"
|
||||
assert (
|
||||
np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
|
||||
), "use ays sigmas should have different outputs"
|
||||
|
||||
def test_ip_adapter(self):
|
||||
expected_pipe_slice = None
|
||||
@@ -742,9 +742,9 @@ class StableDiffusionXLPipelineFastTests(
|
||||
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
|
||||
latents = pipe_1(**inputs_1).images[0]
|
||||
|
||||
assert expected_steps_1 == done_steps, (
|
||||
f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
)
|
||||
assert (
|
||||
expected_steps_1 == done_steps
|
||||
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
inputs_2 = {
|
||||
@@ -771,9 +771,9 @@ class StableDiffusionXLPipelineFastTests(
|
||||
pipe_3(**inputs_3).images[0]
|
||||
|
||||
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
|
||||
assert expected_steps == done_steps, (
|
||||
f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
)
|
||||
assert (
|
||||
expected_steps == done_steps
|
||||
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
|
||||
for steps in [7, 11, 20]:
|
||||
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
|
||||
|
||||
@@ -585,9 +585,9 @@ class StableDiffusionXLInpaintPipelineFastTests(
|
||||
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
|
||||
latents = pipe_1(**inputs_1).images[0]
|
||||
|
||||
assert expected_steps_1 == done_steps, (
|
||||
f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
)
|
||||
assert (
|
||||
expected_steps_1 == done_steps
|
||||
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
|
||||
inputs_2 = {
|
||||
**inputs,
|
||||
@@ -601,9 +601,9 @@ class StableDiffusionXLInpaintPipelineFastTests(
|
||||
pipe_3(**inputs_3).images[0]
|
||||
|
||||
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
|
||||
assert expected_steps == done_steps, (
|
||||
f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
)
|
||||
assert (
|
||||
expected_steps == done_steps
|
||||
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
|
||||
|
||||
for steps in [7, 11, 20]:
|
||||
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
|
||||
|
||||
@@ -167,9 +167,9 @@ class DownloadTests(unittest.TestCase):
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
assert download_requests.count("HEAD") == 15, "15 calls to files"
|
||||
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
|
||||
assert len(download_requests) == 32, (
|
||||
"2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
|
||||
)
|
||||
assert (
|
||||
len(download_requests) == 32
|
||||
), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
DiffusionPipeline.download(
|
||||
@@ -179,9 +179,9 @@ class DownloadTests(unittest.TestCase):
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
|
||||
assert cache_requests.count("GET") == 1, "model info is only GET"
|
||||
assert len(cache_requests) == 2, (
|
||||
"We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
)
|
||||
assert (
|
||||
len(cache_requests) == 2
|
||||
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
|
||||
def test_less_downloads_passed_object(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -217,9 +217,9 @@ class DownloadTests(unittest.TestCase):
|
||||
assert download_requests.count("HEAD") == 13, "13 calls to files"
|
||||
# 17 - 2 because no call to config or model file for `safety_checker`
|
||||
assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
|
||||
assert len(download_requests) == 28, (
|
||||
"2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
|
||||
)
|
||||
assert (
|
||||
len(download_requests) == 28
|
||||
), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
DiffusionPipeline.download(
|
||||
@@ -229,9 +229,9 @@ class DownloadTests(unittest.TestCase):
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
|
||||
assert cache_requests.count("GET") == 1, "model info is only GET"
|
||||
assert len(cache_requests) == 2, (
|
||||
"We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
)
|
||||
assert (
|
||||
len(cache_requests) == 2
|
||||
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
|
||||
def test_download_only_pytorch(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
@@ -191,12 +191,12 @@ class SDFunctionTesterMixin:
|
||||
inputs["output_type"] = "np"
|
||||
output_no_freeu = pipe(**inputs)[0]
|
||||
|
||||
assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
|
||||
"Enabling of FreeU should lead to different results."
|
||||
)
|
||||
assert np.allclose(output, output_no_freeu, atol=1e-2), (
|
||||
f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
|
||||
)
|
||||
assert not np.allclose(
|
||||
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
|
||||
), "Enabling of FreeU should lead to different results."
|
||||
assert np.allclose(
|
||||
output, output_no_freeu, atol=1e-2
|
||||
), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -217,12 +217,12 @@ class SDFunctionTesterMixin:
|
||||
and hasattr(component, "original_attn_processors")
|
||||
and component.original_attn_processors is not None
|
||||
):
|
||||
assert check_qkv_fusion_processors_exist(component), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
)
|
||||
assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
|
||||
"Something wrong with the attention processors concerning the fused QKV projections."
|
||||
)
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
component
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
component, component.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["return_dict"] = False
|
||||
@@ -235,15 +235,15 @@ class SDFunctionTesterMixin:
|
||||
image_disabled = pipe(**inputs)[0]
|
||||
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
|
||||
"Fusion of QKV projections shouldn't affect the outputs."
|
||||
)
|
||||
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
|
||||
"Original outputs should match when fused QKV projections are disabled."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
class IPAdapterTesterMixin:
|
||||
@@ -909,9 +909,9 @@ class PipelineFromPipeTesterMixin:
|
||||
|
||||
for component in pipe_original.components.values():
|
||||
if hasattr(component, "attn_processors"):
|
||||
assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
|
||||
"`from_pipe` changed the attention processor in original pipeline."
|
||||
)
|
||||
assert all(
|
||||
type(proc) == AttnProcessor for proc in component.attn_processors.values()
|
||||
), "`from_pipe` changed the attention processor in original pipeline."
|
||||
|
||||
@require_accelerator
|
||||
@require_accelerate_version_greater("0.14.0")
|
||||
@@ -2569,12 +2569,12 @@ class PyramidAttentionBroadcastTesterMixin:
|
||||
image_slice_pab_disabled = output.flatten()
|
||||
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), (
|
||||
"PAB outputs should not differ much in specified timestep range."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), (
|
||||
"Outputs from normal inference and after disabling cache should not differ."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_enabled, atol=expected_atol
|
||||
), "PAB outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_disabled, atol=1e-4
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
|
||||
class FasterCacheTesterMixin:
|
||||
@@ -2639,12 +2639,12 @@ class FasterCacheTesterMixin:
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
|
||||
"FasterCache outputs should not differ much in specified timestep range."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
|
||||
"Outputs from normal inference and after disabling cache should not differ."
|
||||
)
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
|
||||
), "FasterCache outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
def test_faster_cache_state(self):
|
||||
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
|
||||
@@ -191,12 +191,12 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
)
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
|
||||
f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
)
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_offloads(self):
|
||||
|
||||
@@ -357,9 +357,9 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
prediction_type=prediction_type,
|
||||
final_sigmas_type=final_sigmas_type,
|
||||
)
|
||||
assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
|
||||
f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
|
||||
)
|
||||
assert (
|
||||
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
|
||||
), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
|
||||
|
||||
def test_beta_sigmas(self):
|
||||
self.check_over_configs(use_beta_sigmas=True)
|
||||
|
||||
@@ -345,9 +345,9 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
|
||||
lower_order_final=lower_order_final,
|
||||
final_sigmas_type=final_sigmas_type,
|
||||
)
|
||||
assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
|
||||
f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
|
||||
)
|
||||
assert (
|
||||
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
|
||||
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
|
||||
|
||||
def test_beta_sigmas(self):
|
||||
self.check_over_configs(use_beta_sigmas=True)
|
||||
|
||||
@@ -188,9 +188,9 @@ class EDMDPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
prediction_type=prediction_type,
|
||||
algorithm_type=algorithm_type,
|
||||
)
|
||||
assert not torch.isnan(sample).any(), (
|
||||
f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
|
||||
)
|
||||
assert (
|
||||
not torch.isnan(sample).any()
|
||||
), f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
|
||||
|
||||
def test_lower_order_final(self):
|
||||
self.check_over_configs(lower_order_final=True)
|
||||
|
||||
@@ -245,9 +245,9 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
interpolation_type=interpolation_type,
|
||||
final_sigmas_type=final_sigmas_type,
|
||||
)
|
||||
assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
|
||||
f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
|
||||
)
|
||||
assert (
|
||||
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
|
||||
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
|
||||
|
||||
def test_custom_sigmas(self):
|
||||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
@@ -260,9 +260,9 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
prediction_type=prediction_type,
|
||||
final_sigmas_type=final_sigmas_type,
|
||||
)
|
||||
assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
|
||||
f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
|
||||
)
|
||||
assert (
|
||||
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
|
||||
), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
|
||||
|
||||
def test_beta_sigmas(self):
|
||||
self.check_over_configs(use_beta_sigmas=True)
|
||||
|
||||
@@ -216,9 +216,9 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
prediction_type=prediction_type,
|
||||
timestep_spacing=timestep_spacing,
|
||||
)
|
||||
assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
|
||||
f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
|
||||
)
|
||||
assert (
|
||||
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
|
||||
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
|
||||
|
||||
def test_beta_sigmas(self):
|
||||
self.check_over_configs(use_beta_sigmas=True)
|
||||
|
||||
@@ -72,9 +72,9 @@ class SDSingleFileTesterMixin:
|
||||
continue
|
||||
|
||||
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
|
||||
assert isinstance(component, pipe.components[component_name].__class__), (
|
||||
f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
|
||||
)
|
||||
assert isinstance(
|
||||
component, pipe.components[component_name].__class__
|
||||
), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
|
||||
|
||||
for param_name, param_value in component.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
@@ -85,9 +85,9 @@ class SDSingleFileTesterMixin:
|
||||
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
|
||||
pipe.components[component_name].config[param_name] = param_value
|
||||
|
||||
assert pipe.components[component_name].config[param_name] == param_value, (
|
||||
f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
|
||||
)
|
||||
assert (
|
||||
pipe.components[component_name].config[param_name] == param_value
|
||||
), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
|
||||
|
||||
def test_single_file_components(self, pipe=None, single_file_pipe=None):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
@@ -253,9 +253,9 @@ class SDXLSingleFileTesterMixin:
|
||||
continue
|
||||
|
||||
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
|
||||
assert isinstance(component, pipe.components[component_name].__class__), (
|
||||
f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
|
||||
)
|
||||
assert isinstance(
|
||||
component, pipe.components[component_name].__class__
|
||||
), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
|
||||
|
||||
for param_name, param_value in component.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
@@ -266,9 +266,9 @@ class SDXLSingleFileTesterMixin:
|
||||
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
|
||||
pipe.components[component_name].config[param_name] = param_value
|
||||
|
||||
assert pipe.components[component_name].config[param_name] == param_value, (
|
||||
f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
|
||||
)
|
||||
assert (
|
||||
pipe.components[component_name].config[param_name] == param_value
|
||||
), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
|
||||
|
||||
def test_single_file_components(self, pipe=None, single_file_pipe=None):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
|
||||
@@ -60,9 +60,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
|
||||
@@ -87,9 +87,9 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading"
|
||||
)
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between pretrained loading and single file loading"
|
||||
|
||||
def test_single_file_in_type_variant_components(self):
|
||||
# `in` variant checkpoints require passing in a `config` parameter
|
||||
@@ -106,9 +106,9 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading"
|
||||
)
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between pretrained loading and single file loading"
|
||||
|
||||
def test_single_file_mix_type_variant_components(self):
|
||||
repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
|
||||
@@ -121,6 +121,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading"
|
||||
)
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between pretrained loading and single file loading"
|
||||
|
||||
@@ -58,9 +58,9 @@ class ControlNetModelSingleFileTests(unittest.TestCase):
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_single_file_arguments(self):
|
||||
model_default = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
@@ -58,9 +58,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user