Compare commits

...

6 Commits

Author SHA1 Message Date
DN6 b365801c57 update 2025-04-09 15:34:42 +05:30
DN6 644147a198 Merge branch 'main' into ruff-update 2025-04-09 15:22:55 +05:30
DN6 c852f239f2 update 2025-03-08 08:17:14 +05:30
DN6 be861e236f update 2025-03-08 08:07:10 +05:30
DN6 2d744f0707 Merge branch 'main' into ruff-update 2025-03-08 08:05:08 +05:30
DN6 41c7e72d44 update 2025-02-27 17:08:37 +05:30
200 changed files with 4753 additions and 4694 deletions
@@ -839,9 +839,9 @@ class TokenEmbeddingsHandler:
idx = 0 idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all( assert all(isinstance(tok, str) for tok in inserting_toks), (
isinstance(tok, str) for tok in inserting_toks "All elements in inserting_toks should be strings."
), "All elements in inserting_toks should be strings." )
self.inserting_toks = inserting_toks self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks} special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -1605,7 +1605,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -200,7 +200,8 @@ Special VAE used for training: {vae_path}.
"diffusers", "diffusers",
"diffusers-training", "diffusers-training",
lora, lora,
"template:sd-lora" "stable-diffusion", "template:sd-lora",
"stable-diffusion",
"stable-diffusion-diffusers", "stable-diffusion-diffusers",
] ]
model_card = populate_model_card(model_card, tags=tags) model_card = populate_model_card(model_card, tags=tags)
@@ -724,9 +725,9 @@ class TokenEmbeddingsHandler:
idx = 0 idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all( assert all(isinstance(tok, str) for tok in inserting_toks), (
isinstance(tok, str) for tok in inserting_toks "All elements in inserting_toks should be strings."
), "All elements in inserting_toks should be strings." )
self.inserting_toks = inserting_toks self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks} special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -746,9 +747,9 @@ class TokenEmbeddingsHandler:
.to(dtype=self.dtype) .to(dtype=self.dtype)
* std_token_embedding * std_token_embedding
) )
self.embeddings_settings[ self.embeddings_settings[f"original_embeddings_{idx}"] = (
f"original_embeddings_{idx}" text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool) inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1322,7 +1323,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -890,9 +890,9 @@ class TokenEmbeddingsHandler:
idx = 0 idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all( assert all(isinstance(tok, str) for tok in inserting_toks), (
isinstance(tok, str) for tok in inserting_toks "All elements in inserting_toks should be strings."
), "All elements in inserting_toks should be strings." )
self.inserting_toks = inserting_toks self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks} special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -912,9 +912,9 @@ class TokenEmbeddingsHandler:
.to(dtype=self.dtype) .to(dtype=self.dtype)
* std_token_embedding * std_token_embedding
) )
self.embeddings_settings[ self.embeddings_settings[f"original_embeddings_{idx}"] = (
f"original_embeddings_{idx}" text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool) inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1647,7 +1647,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -1138,7 +1138,7 @@ def main(args):
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+1 -1
View File
@@ -1159,7 +1159,7 @@ def main(args):
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -701,7 +701,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
raise ValueError("`max_tile_size` cannot be None.") raise ValueError("`max_tile_size` cannot be None.")
elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280): elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
raise ValueError( raise ValueError(
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}." f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}."
) )
if tile_gaussian_sigma is None: if tile_gaussian_sigma is None:
raise ValueError("`tile_gaussian_sigma` cannot be None.") raise ValueError("`tile_gaussian_sigma` cannot be None.")
@@ -488,7 +488,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -496,7 +496,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512: if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+6 -6
View File
@@ -907,12 +907,12 @@ def create_controller(
# reweight # reweight
if edit_type == "reweight": if edit_type == "reweight":
assert ( assert equalizer_words is not None and equalizer_strengths is not None, (
equalizer_words is not None and equalizer_strengths is not None "To use reweight edit, please specify equalizer_words and equalizer_strengths."
), "To use reweight edit, please specify equalizer_words and equalizer_strengths." )
assert len(equalizer_words) == len( assert len(equalizer_words) == len(equalizer_strengths), (
equalizer_strengths "equalizer_words and equalizer_strengths must be of same length."
), "equalizer_words and equalizer_strengths must be of same length." )
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight( return AttentionReweight(
prompts, prompts,
@@ -1028,7 +1028,7 @@ class StableDiffusionXL_AE_Pipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -1036,7 +1036,7 @@ class StableDiffusionXL_AE_Pipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
+1 -2
View File
@@ -288,8 +288,7 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
timesteps = np.array(timesteps, dtype=np.int64) timesteps = np.array(timesteps, dtype=np.int64)
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter # Set alpha parameter
if "lora_down" in kohya_key: if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha' alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict return kohya_ss_state_dict
@@ -901,7 +901,7 @@ def main(args):
unet_ = accelerator.unwrap_model(unet) unet_ = accelerator.unwrap_model(unet)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = { unet_state_dict = {
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
} }
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter # Set alpha parameter
if "lora_down" in kohya_key: if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha' alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict return kohya_ss_state_dict
+5 -3
View File
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0 total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images) pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open( with (
f"{class_data_dir}/images.txt", "w" open(f"{class_data_dir}/caption.txt", "w") as f1,
) as f3: open(f"{class_data_dir}/urls.txt", "w") as f2,
open(f"{class_data_dir}/images.txt", "w") as f3,
):
while total < num_class_images: while total < num_class_images:
images = class_images[count] images = class_images[count]
count += 1 count += 1
@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists(): if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True) class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior: if args.real_prior:
assert ( assert (class_images_dir / "images").exists(), (
class_images_dir / "images" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
len(list((class_images_dir / "images").iterdir())) == args.num_class_images f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert (class_images_dir / "caption.txt").exists(), (
class_images_dir / "caption.txt" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert (class_images_dir / "images.txt").exists(), (
class_images_dir / "images.txt" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt") concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt") concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept args.concepts_list[i] = concept
+1 -1
View File
@@ -1014,7 +1014,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
+1 -1
View File
@@ -982,7 +982,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -1294,7 +1294,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1053,7 +1053,7 @@ def main(args):
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir) lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1064,7 +1064,7 @@ def main(args):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir) lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1355,7 +1355,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -118,7 +118,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} # {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
@@ -1286,7 +1286,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
pipeline.load_lora_weights(args.output_dir) pipeline.load_lora_weights(args.output_dir)
assert ( assert pipeline.transformer.config.in_channels == initial_channels * 2, (
pipeline.transformer.config.in_channels == initial_channels * 2 f"{pipeline.transformer.config.in_channels=}"
), f"{pipeline.transformer.config.in_channels=}" )
pipeline.to(accelerator.device) pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
@@ -954,7 +954,7 @@ def main(args):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir) lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = { transformer_lora_state_dict = {
f'{k.replace("transformer.", "")}': v f"{k.replace('transformer.', '')}": v
for k, v in lora_state_dict.items() for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k if k.startswith("transformer.") and "lora" in k
} }
+3 -3
View File
@@ -1081,9 +1081,9 @@ class AutoConfig:
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}" f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
) )
pretrained_model_name_or_paths[ pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
pretrained_model_name_or_paths.index(search_word) textual_inversion_path.model_path
] = textual_inversion_path.model_path )
self.load_textual_inversion( self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors="pt", return_tensors="pt",
) )
tokens = batch_encoding["input_ids"] tokens = batch_encoding["input_ids"]
assert ( assert torch.count_nonzero(tokens - 49407) == 2, (
torch.count_nonzero(tokens - 49407) == 2 f"String '{string}' maps to more than a single token. Please use another string"
), f"String '{string}' maps to more than a single token. Please use another string" )
return tokens[0, 1] return tokens[0, 1]
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
assert ( assert H == self.img_size[0] and W == self.img_size[1], (
H == self.img_size[0] and W == self.img_size[1] f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." )
x = self.proj(x).flatten(2).permute(0, 2, 1) x = self.proj(x).flatten(2).permute(0, 2, 1)
return x return x
@@ -803,21 +803,20 @@ def parse_args(input_args=None):
"--control_type", "--control_type",
type=str, type=str,
default="canny", default="canny",
help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."), help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
) )
parser.add_argument( parser.add_argument(
"--transformer_layers_per_block", "--transformer_layers_per_block",
type=str, type=str,
default=None, default=None,
help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."), help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
) )
parser.add_argument( parser.add_argument(
"--old_style_controlnet", "--old_style_controlnet",
action="store_true", action="store_true",
default=False, default=False,
help=( help=(
"Use the old style controlnet, which is a single transformer layer with" "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
" a single head. Defaults to False."
), ),
) )
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
# create pipeline # create pipeline
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
@@ -683,7 +683,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
@@ -790,7 +790,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -783,7 +783,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
File diff suppressed because one or more lines are too long
+17 -22
View File
@@ -26,8 +26,7 @@
"%load_ext autoreload\n", "%load_ext autoreload\n",
"%autoreload 2\n", "%autoreload 2\n",
"\n", "\n",
"import torch\n", "from diffusers import StableDiffusionGLIGENPipeline"
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
] ]
}, },
{ {
@@ -36,28 +35,25 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"import diffusers\n", "import diffusers\n",
"from diffusers import (\n", "from diffusers import (\n",
" AutoencoderKL,\n", " AutoencoderKL,\n",
" DDPMScheduler,\n", " DDPMScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n", " EulerDiscreteScheduler,\n",
" UNet2DConditionModel,\n",
")\n", ")\n",
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", "\n",
"\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n", "\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n", "\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n", "tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n", "noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n", "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n", "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n", "# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n", "# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n", "# )\n",
@@ -71,9 +67,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"unet = UNet2DConditionModel.from_pretrained(\n", "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
] ]
}, },
{ {
@@ -108,6 +102,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n",
"\n",
"\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n", "# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n", "# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n", "\n",
@@ -117,10 +114,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n", "# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n", "# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n", "\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n", "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n", "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n", "\n",
"boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n", "boxes = boxes / 512\n",
@@ -179,7 +174,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "densecaption", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -197,5 +192,5 @@
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 4
} }
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match # Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",") instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",") instance_prompt = args.instance_prompt.split(",")
assert all( assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] "Instance data dir and prompt inputs are not of the same length."
), "Instance data dir and prompt inputs are not of the same length." )
if args.with_prior_preservation: if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",") class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None) negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts args.validation_negative_prompt = negative_validation_prompts
assert num_of_validation_prompts == len( assert num_of_validation_prompts == len(negative_validation_prompts), (
negative_validation_prompts "The length of negative prompts for validation is greater than the number of validation prompts."
), "The length of negative prompts for validation is greater than the number of validation prompts." )
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator) index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -663,8 +663,7 @@ class PromptDiffusionPipeline(
self.check_image(image, prompt, prompt_embeds) self.check_image(image, prompt, prompt_embeds)
else: else:
raise ValueError( raise ValueError(
f"You have passed a list of images of length {len(image_pair)}." f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
f"Make sure the list size equals to two."
) )
# Check `controlnet_conditioning_scale` # Check `controlnet_conditioning_scale`
@@ -1057,7 +1057,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
@@ -1021,7 +1021,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -118,7 +118,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} # {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
@@ -1336,7 +1336,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -750,7 +750,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -767,7 +767,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
index_no_updates_2 orig_embeds_params_2[index_no_updates_2]
] = orig_embeds_params_2[index_no_updates_2] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
+3 -3
View File
@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path} --model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1 --checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir} --output_dir {tmpdir}
--seed=0 --seed=0
""".split() """.split()
@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path} --model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1 --checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir} --output_dir {tmpdir}
--use_ema --use_ema
--seed=0 --seed=0
@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir} --output_dir {tmpdir}
--checkpointing_steps=2 --checkpointing_steps=2
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2 --checkpoints_total_limit=2
--seed=0 --seed=0
""".split() """.split()
+6 -6
View File
@@ -653,15 +653,15 @@ def main():
try: try:
# Gets the resolution of the timm transformation after centercrop # Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1] timm_centercrop_transform = timm_transform.transforms[1]
assert isinstance( assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
timm_centercrop_transform, transforms.CenterCrop f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." )
timm_model_resolution = timm_centercrop_transform.size[0] timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization # Gets final normalization
timm_model_normalization = timm_transform.transforms[-1] timm_model_normalization = timm_transform.transforms[-1]
assert isinstance( assert isinstance(timm_model_normalization, transforms.Normalize), (
timm_model_normalization, transforms.Normalize f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." )
except AssertionError as e: except AssertionError as e:
raise NotImplementedError(e) raise NotImplementedError(e)
# Enable flash attention if asked # Enable flash attention if asked
+1 -1
View File
@@ -3,7 +3,7 @@ line-length = 119
[tool.ruff.lint] [tool.ruff.lint]
# Never enforce `E501` (line length violations). # Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"] ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"] select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files. # Ignore import violations in all `__init__.py` files.
@@ -261,9 +261,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0] model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path): if not os.path.isfile(args.model_path):
assert ( assert model_name == args.model_path, (
model_name == args.model_path f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" )
args.model_path = download(model_name) args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"] sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -290,9 +290,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items(): for key, value in renamed_state_dict.items():
assert ( assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" )
if key == "time_proj.weight": if key == "time_proj.weight":
value = value.squeeze() value = value.squeeze()
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config( model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer" "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
) )
model_config[ model_config["use_style_cond_and_image_meta_size"] = (
"use_style_cond_and_image_meta_size" args.use_style_cond_and_image_meta_size
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False ) ### version <= v1.1: True; version >= v1.2: False
print(model_config) print(model_config)
for key in state_dict: for key in state_dict:
+4 -5
View File
@@ -13,15 +13,14 @@ def main(args):
state_dict = state_dict[args.load_key] state_dict = state_dict[args.load_key]
except KeyError: except KeyError:
raise KeyError( raise KeyError(
f"{args.load_key} not found in the checkpoint." f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
f"Please load from the following keys:{state_dict.keys()}"
) )
device = "cuda" device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer") model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
model_config[ model_config["use_style_cond_and_image_meta_size"] = (
"use_style_cond_and_image_meta_size" args.use_style_cond_and_image_meta_size
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False ) ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim # input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict: for key in state_dict:
+3 -3
View File
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"] block_out_channels = original_config["channels"]
assert ( assert len(set(original_config["depths"])) == 1, (
len(set(original_config["depths"])) == 1 "UNet2DConditionModel currently do not support blocks with different number of layers"
), "UNet2DConditionModel currently do not support blocks with different number of layers" )
layers_per_block = original_config["depths"][0] layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"] class_labels_dim = original_config["mapping_cond_dim"]
+9 -7
View File
@@ -223,7 +223,9 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
f"blocks.{block + 1}.proj.weight" f"blocks.{block + 1}.proj.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias") new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.proj.bias"
)
# Convert block_out (MochiMidBlock3D) # Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3 for i in range(3): # layers_per_block[0] = 3
@@ -303,9 +305,9 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for i in range(down_block_layers[block]): for i in range(down_block_layers[block]):
# Convert resnets # Convert resnets
new_state_dict[ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight") )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias" f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
) )
@@ -315,9 +317,9 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias" f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
) )
new_state_dict[ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight") )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias" f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
) )
+6 -6
View File
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor # TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[ new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] )
if len(attentions): if len(attentions):
paths = renew_attention_paths(attentions) paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
) )
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[ new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] )
output_block_list = {k: sorted(v) for k, v in output_block_list.items()} output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values(): if ["conv.bias", "conv.weight"] in output_block_list.values():
+12 -12
View File
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
def vqvae_model_from_original_config(original_config): def vqvae_model_from_original_config(original_config):
assert ( assert original_config["target"] in PORTED_VQVAES, (
original_config["target"] in PORTED_VQVAES f"{original_config['target']} has not yet been ported to diffusers."
), f"{original_config['target']} has not yet been ported to diffusers." )
original_config = original_config["params"] original_config = original_config["params"]
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
def transformer_model_from_original_config( def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config original_diffusion_config, original_transformer_config, original_content_embedding_config
): ):
assert ( assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
original_diffusion_config["target"] in PORTED_DIFFUSIONS f"{original_diffusion_config['target']} has not yet been ported to diffusers."
), f"{original_diffusion_config['target']} has not yet been ported to diffusers." )
assert ( assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
original_transformer_config["target"] in PORTED_TRANSFORMERS f"{original_transformer_config['target']} has not yet been ported to diffusers."
), f"{original_transformer_config['target']} has not yet been ported to diffusers." )
assert ( assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers." )
original_diffusion_config = original_diffusion_config["params"] original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"] original_transformer_config = original_transformer_config["params"]
+1 -1
View File
@@ -122,7 +122,7 @@ _deps = [
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
"python>=3.8.0", "python>=3.8.0",
"ruff==0.1.5", "ruff==0.9.10",
"safetensors>=0.3.1", "safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19", "GitPython<3.1.19",
+1 -1
View File
@@ -29,7 +29,7 @@ deps = {
"pytest-timeout": "pytest-timeout", "pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist", "pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0", "python": "python>=3.8.0",
"ruff": "ruff==0.1.5", "ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1", "safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19", "GitPython": "GitPython<3.1.19",
+1 -2
View File
@@ -295,8 +295,7 @@ class IPAdapterMixin:
): ):
if len(scale_configs) != len(attn_processor.scale): if len(scale_configs) != len(attn_processor.scale):
raise ValueError( raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to " f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
f"{len(attn_processor.scale)} IP-Adapter."
) )
elif len(scale_configs) == 1: elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale) scale_configs = scale_configs * len(attn_processor.scale)
+33 -33
View File
@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present. # Store DoRA scale if present.
if dora_present_in_unet: if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
# Handle text encoder LoRAs. # Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
) )
if lora_name.startswith(("lora_te_", "lora_te1_")): if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
elif lora_name.startswith("lora_te2_"): elif lora_name.startswith("lora_te2_"):
te2_state_dict[ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
# Store alpha if present. # Store alpha if present.
if lora_name_alpha in state_dict: if lora_name_alpha in state_dict:
@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
for lora_key in ["lora_A", "lora_B"]: for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in ## time_text_embed.timestep_embedder <- time_in
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") )
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") )
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") )
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") )
## time_text_embed.text_embedder <- vector_in ## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
# guidance # guidance
has_guidance = any("guidance" in k for k in original_state_dict) has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance: if has_guidance:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") )
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") )
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") )
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") )
# context_embedder # context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+1 -1
View File
@@ -26,6 +26,7 @@ _import_structure = {}
if is_torch_available(): if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -41,7 +42,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["cache_utils"] = ["CacheMixin"] _import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+1 -1
View File
@@ -205,7 +205,7 @@ def load_state_dict(
) from e ) from e
except (UnicodeDecodeError, ValueError): except (UnicodeDecodeError, ValueError):
raise OSError( raise OSError(
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
) )
@@ -211,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
def _init_vectorized_inputs(self, norm_type): def _init_vectorized_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert ( assert self.config.num_vector_embeds is not None, (
self.config.num_vector_embeds is not None "Transformer2DModel over discrete input must provide num_embed"
), "Transformer2DModel over discrete input must provide num_embed" )
self.height = self.config.sample_size self.height = self.config.sample_size
self.width = self.config.sample_size self.width = self.config.sample_size
@@ -791,7 +791,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
if transcription is None: if transcription is None:
if self.text_encoder_2.config.model_type == "vits": if self.text_encoder_2.config.model_type == "vits":
raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription") raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
elif transcription is not None and ( elif transcription is not None and (
not isinstance(transcription, str) and not isinstance(transcription, list) not isinstance(transcription, str) and not isinstance(transcription, list)
): ):
@@ -657,7 +657,7 @@ class StableDiffusionControlNetInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -665,7 +665,7 @@ class StableDiffusionControlNetInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple # `prompt` needs more sophisticated handling when there are multiple
# conditionings. # conditionings.
@@ -507,7 +507,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -515,7 +515,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512: if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -574,7 +574,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -582,7 +582,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512: if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+3 -3
View File
@@ -341,9 +341,9 @@ class AnimateDiffFreeNoiseMixin:
start_tensor = negative_prompt_embeds[i].unsqueeze(0) start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
negative_prompt_interpolation_embeds[ negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
start_frame : end_frame + 1 self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) )
prompt_embeds = prompt_interpolation_embeds prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds negative_prompt_embeds = negative_prompt_interpolation_embeds
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
""" """
_load_connected_pipes = True _load_connected_pipes = True
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq" model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
_exclude_from_cpu_offload = ["prior_prior"] _exclude_from_cpu_offload = ["prior_prior"]
def __init__( def __init__(
@@ -95,13 +95,13 @@ class OmniGenMultiModalProcessor:
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids)) unique_image_ids = sorted(set(image_ids))
assert unique_image_ids == list( assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
range(1, len(unique_image_ids) + 1) f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" )
# total images must be the same as the number of image tags # total images must be the same as the number of image tags
assert ( assert len(unique_image_ids) == len(input_images), (
len(unique_image_ids) == len(input_images) f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" )
input_images = [input_images[x - 1] for x in image_ids] input_images = [input_images[x - 1] for x in image_ids]
@@ -604,7 +604,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -612,7 +612,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple # `prompt` needs more sophisticated handling when there are multiple
# conditionings. # conditionings.
@@ -683,7 +683,7 @@ class StableDiffusionPAGInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -691,7 +691,7 @@ class StableDiffusionPAGInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -737,7 +737,7 @@ class StableDiffusionXLPAGInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -745,7 +745,7 @@ class StableDiffusionXLPAGInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -323,9 +323,7 @@ def maybe_raise_or_warn(
model_cls = unwrapped_sub_model.__class__ model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj): if not issubclass(model_cls, expected_class_obj):
raise ValueError( raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
)
else: else:
logger.warning( logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+6 -6
View File
@@ -983,9 +983,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
fields = torch.cat(fields, dim=1) fields = torch.cat(fields, dim=1)
fields = fields.float() fields = fields.float()
assert ( assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
len(fields.shape) == 3 and fields.shape[-1] == 1 f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" )
fields = fields.reshape(1, *([grid_size] * 3)) fields = fields.reshape(1, *([grid_size] * 3))
@@ -1039,9 +1039,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
textures = textures.float() textures = textures.float()
# 3.3 augument the mesh with texture data # 3.3 augument the mesh with texture data
assert len(textures.shape) == 3 and textures.shape[-1] == len( assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
texture_channels f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" )
for m, texture in zip(raw_meshes, textures): for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)] texture = texture[: len(m.verts)]
@@ -660,7 +660,7 @@ class StableDiffusionInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -668,7 +668,7 @@ class StableDiffusionInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -741,7 +741,7 @@ class StableDiffusionXLInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -749,7 +749,7 @@ class StableDiffusionXLInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -334,7 +334,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
) )
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}") raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
if height % 16 != 0 or width % 16 != 0: if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+4 -8
View File
@@ -215,19 +215,15 @@ class DiffusersQuantizer(ABC):
) )
@abstractmethod @abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs): def _process_model_before_weight_loading(self, model, **kwargs): ...
...
@abstractmethod @abstractmethod
def _process_model_after_weight_loading(self, model, **kwargs): def _process_model_after_weight_loading(self, model, **kwargs): ...
...
@property @property
@abstractmethod @abstractmethod
def is_serializable(self): def is_serializable(self): ...
...
@property @property
@abstractmethod @abstractmethod
def is_trainable(self): def is_trainable(self): ...
...
@@ -203,8 +203,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
timesteps = np.array(timesteps, dtype=np.int64) timesteps = np.array(timesteps, dtype=np.int64)
+1 -2
View File
@@ -279,8 +279,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
timesteps = np.array(timesteps, dtype=np.int64) timesteps = np.array(timesteps, dtype=np.int64)
@@ -289,8 +289,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
timesteps = np.array(timesteps, dtype=np.int64) timesteps = np.array(timesteps, dtype=np.int64)
+1 -2
View File
@@ -413,8 +413,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
+1 -2
View File
@@ -431,8 +431,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
+2 -2
View File
@@ -241,7 +241,7 @@ def _set_state_dict_into_text_encoder(
""" """
text_encoder_state_dict = { text_encoder_state_dict = {
f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
} }
text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
@@ -583,7 +583,7 @@ class EMAModel:
""" """
if self.temp_stored_params is None: if self.temp_stored_params is None:
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
if self.foreach: if self.foreach:
torch._foreach_copy_( torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params] [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
+1 -2
View File
@@ -60,8 +60,7 @@ def _get_default_logging_level() -> int:
return log_levels[env_level_str] return log_levels[env_level_str]
else: else:
logging.getLogger().warning( logging.getLogger().warning(
f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
f"has to be one of: { ', '.join(log_levels.keys()) }"
) )
return _default_log_level return _default_log_level
+1 -1
View File
@@ -334,7 +334,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key: if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha' alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict return kohya_ss_state_dict
+1 -1
View File
@@ -1027,7 +1027,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
process.join(timeout=timeout) process.join(timeout=timeout)
if results["error"] is not None: if results["error"] is not None:
test_case.fail(f'{results["error"]}') test_case.fail(f"{results['error']}")
class CaptureLogger: class CaptureLogger:
+2 -9
View File
@@ -168,9 +168,7 @@ class HookTests(unittest.TestCase):
registry.register_hook(MultiplyHook(2), "multiply_hook") registry.register_hook(MultiplyHook(2), "multiply_hook")
registry_repr = repr(registry) registry_repr = repr(registry)
expected_repr = ( expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
"HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
)
self.assertEqual(len(registry.hooks), 2) self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
@@ -285,12 +283,7 @@ class HookTests(unittest.TestCase):
self.model(input) self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "") output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = ( expected_invocation_order_log = (
( ("MultiplyHook pre_forward\nAddHook pre_forward\nAddHook post_forward\nMultiplyHook post_forward\n")
"MultiplyHook pre_forward\n"
"AddHook pre_forward\n"
"AddHook post_forward\n"
"MultiplyHook post_forward\n"
)
.replace(" ", "") .replace(" ", "")
.replace("\n", "") .replace("\n", "")
) )
+6 -6
View File
@@ -299,9 +299,9 @@ class ModelUtilsTest(unittest.TestCase):
) )
download_requests = [r.method for r in m.request_history] download_requests = [r.method for r in m.request_history]
assert ( assert download_requests.count("HEAD") == 3, (
download_requests.count("HEAD") == 3 "3 HEAD requests one for config, one for model, and one for shard index file."
), "3 HEAD requests one for config, one for model, and one for shard index file." )
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m: with requests_mock.mock(real_http=True) as m:
@@ -313,9 +313,9 @@ class ModelUtilsTest(unittest.TestCase):
) )
cache_requests = [r.method for r in m.request_history] cache_requests = [r.method for r in m.request_history]
assert ( assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
"HEAD" == cache_requests[0] and len(cache_requests) == 2 "We should call only `model_info` to check for commit hash and knowing if shard index is present."
), "We should call only `model_info` to check for commit hash and knowing if shard index is present." )
def test_weight_overwrite(self): def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
@@ -92,9 +92,9 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
assert ( assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor" "xformers is not enabled"
), "xformers is not enabled" )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self): def test_set_attn_processor_for_determinism(self):
@@ -167,9 +167,9 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
assert ( assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor" "xformers is not enabled"
), "xformers is not enabled" )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self): def test_set_attn_processor_for_determinism(self):
@@ -654,22 +654,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
assert full_cond_keepallmask_out.allclose( assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
full_cond_out, rtol=1e-05, atol=1e-05 "a 'keep all' mask should give the same result as no mask"
), "a 'keep all' mask should give the same result as no mask" )
trunc_cond = cond[:, :-1, :] trunc_cond = cond[:, :-1, :]
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
assert not trunc_cond_out.allclose( assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
full_cond_out, rtol=1e-05, atol=1e-05 "discarding the last token from our cond should change the result"
), "discarding the last token from our cond should change the result" )
batch, tokens, _ = cond.shape batch, tokens, _ = cond.shape
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
assert masked_cond_out.allclose( assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
trunc_cond_out, rtol=1e-05, atol=1e-05 "masking the last token from our cond should be equivalent to truncating that token out of the condition"
), "masking the last token from our cond should be equivalent to truncating that token out of the condition" )
# see diffusers.models.attention_processor::Attention#prepare_attention_mask # see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
@@ -697,9 +697,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose( assert trunc_mask_out.allclose(keeplast_out), (
keeplast_out "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." )
def test_custom_diffusion_processors(self): def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing # enable deterministic behavior for gradient checkpointing
@@ -1114,12 +1114,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
with torch.no_grad(): with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose( assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4 "LoRA injected UNet should produce different results."
), "LoRA injected UNet should produce different results." )
assert torch.allclose( assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4 "Loading from a saved checkpoint should produce identical results."
), "Loading from a saved checkpoint should produce identical results." )
@require_peft_backend @require_peft_backend
def test_save_attn_procs_raise_warning(self): def test_save_attn_procs_raise_warning(self):
+15 -15
View File
@@ -65,9 +65,9 @@ class ImageProcessorTest(unittest.TestCase):
) )
out_np = self.to_np(out) out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np in_np = (input_np * 255).round() if output_type == "pil" else input_np
assert ( assert np.abs(in_np - out_np).max() < 1e-6, (
np.abs(in_np - out_np).max() < 1e-6 f"decoded output does not match input for output_type {output_type}"
), f"decoded output does not match input for output_type {output_type}" )
def test_vae_image_processor_np(self): def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -78,9 +78,9 @@ class ImageProcessorTest(unittest.TestCase):
out_np = self.to_np(out) out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np in_np = (input_np * 255).round() if output_type == "pil" else input_np
assert ( assert np.abs(in_np - out_np).max() < 1e-6, (
np.abs(in_np - out_np).max() < 1e-6 f"decoded output does not match input for output_type {output_type}"
), f"decoded output does not match input for output_type {output_type}" )
def test_vae_image_processor_pil(self): def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -93,9 +93,9 @@ class ImageProcessorTest(unittest.TestCase):
for i, o in zip(input_pil, out): for i, o in zip(input_pil, out):
in_np = np.array(i) in_np = np.array(i)
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round() out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
assert ( assert np.abs(in_np - out_np).max() < 1e-6, (
np.abs(in_np - out_np).max() < 1e-6 f"decoded output does not match input for output_type {output_type}"
), f"decoded output does not match input for output_type {output_type}" )
def test_preprocess_input_3d(self): def test_preprocess_input_3d(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
@@ -293,9 +293,9 @@ class ImageProcessorTest(unittest.TestCase):
scale = 2 scale = 2
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale) out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
exp_pt_shape = (b, c, h // scale, w // scale) exp_pt_shape = (b, c, h // scale, w // scale)
assert ( assert out_pt.shape == exp_pt_shape, (
out_pt.shape == exp_pt_shape f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'." )
def test_vae_image_processor_resize_np(self): def test_vae_image_processor_resize_np(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1) image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
@@ -305,6 +305,6 @@ class ImageProcessorTest(unittest.TestCase):
input_np = self.to_np(input_pt) input_np = self.to_np(input_pt)
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale) out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
exp_np_shape = (b, h // scale, w // scale, c) exp_np_shape = (b, h // scale, w // scale, c)
assert ( assert out_np.shape == exp_np_shape, (
out_np.shape == exp_np_shape f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'." )
+1 -2
View File
@@ -126,8 +126,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators") @unittest.skip("aMUSEd does not support lists of generators")
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self): ...
...
@slow @slow
@@ -126,8 +126,7 @@ class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators") @unittest.skip("aMUSEd does not support lists of generators")
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self): ...
...
@slow @slow
@@ -130,8 +130,7 @@ class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
@unittest.skip("aMUSEd does not support lists of generators") @unittest.skip("aMUSEd does not support lists of generators")
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self): ...
...
@slow @slow
@@ -106,9 +106,9 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level. # to the pipeline level.
pipe.transformer.fuse_qkv_projections() pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist( assert check_qkv_fusion_processors_exist(pipe.transformer), (
pipe.transformer "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." )
assert check_qkv_fusion_matches_attn_procs_length( assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections." ), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -122,15 +122,15 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
image = pipe(**inputs).images image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1] image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose( assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 "Fusion of QKV projections shouldn't affect the outputs."
), "Fusion of QKV projections shouldn't affect the outputs." )
assert np.allclose( assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." )
assert np.allclose( assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 "Original outputs should match when fused QKV projections are disabled."
), "Original outputs should match when fused QKV projections are disabled." )
@unittest.skip("xformers attention processor does not exist for AuraFlow") @unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self): def test_xformers_attention_forwardGenerator_pass(self):
@@ -195,9 +195,9 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
[0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007] [0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
) )
assert ( assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}" )
@unittest.skip("Test not supported because of complexities in deriving query_embeds.") @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self): def test_encode_prompt_works_in_isolation(self):
+12 -12
View File
@@ -299,9 +299,9 @@ class CogVideoXPipelineFastTests(
original_image_slice = frames[0, -2:, -1, -3:, -3:] original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections() pipe.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist( assert check_qkv_fusion_processors_exist(pipe.transformer), (
pipe.transformer "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." )
assert check_qkv_fusion_matches_attn_procs_length( assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections." ), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -315,15 +315,15 @@ class CogVideoXPipelineFastTests(
frames = pipe(**inputs).frames frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:] image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
assert np.allclose( assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 "Fusion of QKV projections shouldn't affect the outputs."
), "Fusion of QKV projections shouldn't affect the outputs." )
assert np.allclose( assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." )
assert np.allclose( assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 "Original outputs should match when fused QKV projections are disabled."
), "Original outputs should match when fused QKV projections are disabled." )
@slow @slow
@@ -299,9 +299,9 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
original_image_slice = frames[0, -2:, -1, -3:, -3:] original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections() pipe.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist( assert check_qkv_fusion_processors_exist(pipe.transformer), (
pipe.transformer "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." )
assert check_qkv_fusion_matches_attn_procs_length( assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections." ), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -315,12 +315,12 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
frames = pipe(**inputs).frames frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:] image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
assert np.allclose( assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 "Fusion of QKV projections shouldn't affect the outputs."
), "Fusion of QKV projections shouldn't affect the outputs." )
assert np.allclose( assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." )
assert np.allclose( assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 "Original outputs should match when fused QKV projections are disabled."
), "Original outputs should match when fused QKV projections are disabled." )

Some files were not shown because too many files have changed in this diff Show More