Compare commits

...

6 Commits

Author SHA1 Message Date
DN6 b365801c57 update 2025-04-09 15:34:42 +05:30
DN6 644147a198 Merge branch 'main' into ruff-update 2025-04-09 15:22:55 +05:30
DN6 c852f239f2 update 2025-03-08 08:17:14 +05:30
DN6 be861e236f update 2025-03-08 08:07:10 +05:30
DN6 2d744f0707 Merge branch 'main' into ruff-update 2025-03-08 08:05:08 +05:30
DN6 41c7e72d44 update 2025-02-27 17:08:37 +05:30
200 changed files with 4753 additions and 4694 deletions
@@ -839,9 +839,9 @@ class TokenEmbeddingsHandler:
idx = 0 idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all( assert all(isinstance(tok, str) for tok in inserting_toks), (
isinstance(tok, str) for tok in inserting_toks "All elements in inserting_toks should be strings."
), "All elements in inserting_toks should be strings." )
self.inserting_toks = inserting_toks self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks} special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -1605,7 +1605,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -200,7 +200,8 @@ Special VAE used for training: {vae_path}.
"diffusers", "diffusers",
"diffusers-training", "diffusers-training",
lora, lora,
"template:sd-lora" "stable-diffusion", "template:sd-lora",
"stable-diffusion",
"stable-diffusion-diffusers", "stable-diffusion-diffusers",
] ]
model_card = populate_model_card(model_card, tags=tags) model_card = populate_model_card(model_card, tags=tags)
@@ -724,9 +725,9 @@ class TokenEmbeddingsHandler:
idx = 0 idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all( assert all(isinstance(tok, str) for tok in inserting_toks), (
isinstance(tok, str) for tok in inserting_toks "All elements in inserting_toks should be strings."
), "All elements in inserting_toks should be strings." )
self.inserting_toks = inserting_toks self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks} special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -746,9 +747,9 @@ class TokenEmbeddingsHandler:
.to(dtype=self.dtype) .to(dtype=self.dtype)
* std_token_embedding * std_token_embedding
) )
self.embeddings_settings[ self.embeddings_settings[f"original_embeddings_{idx}"] = (
f"original_embeddings_{idx}" text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool) inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1322,7 +1323,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -890,9 +890,9 @@ class TokenEmbeddingsHandler:
idx = 0 idx = 0
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
assert all( assert all(isinstance(tok, str) for tok in inserting_toks), (
isinstance(tok, str) for tok in inserting_toks "All elements in inserting_toks should be strings."
), "All elements in inserting_toks should be strings." )
self.inserting_toks = inserting_toks self.inserting_toks = inserting_toks
special_tokens_dict = {"additional_special_tokens": self.inserting_toks} special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -912,9 +912,9 @@ class TokenEmbeddingsHandler:
.to(dtype=self.dtype) .to(dtype=self.dtype)
* std_token_embedding * std_token_embedding
) )
self.embeddings_settings[ self.embeddings_settings[f"original_embeddings_{idx}"] = (
f"original_embeddings_{idx}" text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() )
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool) inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1647,7 +1647,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
+1 -1
View File
@@ -720,7 +720,7 @@ def main(args):
# Train! # Train!
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(f" Num training steps = {args.max_train_steps}") logger.info(f" Num training steps = {args.max_train_steps}")
logger.info(f" Instantaneous batch size per device = { args.train_batch_size}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -1138,7 +1138,7 @@ def main(args):
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+1 -1
View File
@@ -1159,7 +1159,7 @@ def main(args):
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1103,7 +1103,7 @@ class AdaptiveMaskInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `default_mask_image` or `image` input." " `pipeline.unet` or your `default_mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
+1 -1
View File
@@ -686,7 +686,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
+1 -1
View File
@@ -362,7 +362,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
+2 -2
View File
@@ -1120,7 +1120,7 @@ class LLMGroundedDiffusionPipeline(
if verbose: if verbose:
logger.info( logger.info(
f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}" f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
) )
try: try:
@@ -1184,7 +1184,7 @@ class LLMGroundedDiffusionPipeline(
if verbose: if verbose:
logger.info( logger.info(
f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}" f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
) )
finally: finally:
@@ -701,7 +701,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
raise ValueError("`max_tile_size` cannot be None.") raise ValueError("`max_tile_size` cannot be None.")
elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280): elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
raise ValueError( raise ValueError(
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}." f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}."
) )
if tile_gaussian_sigma is None: if tile_gaussian_sigma is None:
raise ValueError("`tile_gaussian_sigma` cannot be None.") raise ValueError("`tile_gaussian_sigma` cannot be None.")
@@ -488,7 +488,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -496,7 +496,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512: if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+6 -6
View File
@@ -907,12 +907,12 @@ def create_controller(
# reweight # reweight
if edit_type == "reweight": if edit_type == "reweight":
assert ( assert equalizer_words is not None and equalizer_strengths is not None, (
equalizer_words is not None and equalizer_strengths is not None "To use reweight edit, please specify equalizer_words and equalizer_strengths."
), "To use reweight edit, please specify equalizer_words and equalizer_strengths." )
assert len(equalizer_words) == len( assert len(equalizer_words) == len(equalizer_strengths), (
equalizer_strengths "equalizer_words and equalizer_strengths must be of same length."
), "equalizer_words and equalizer_strengths must be of same length." )
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight( return AttentionReweight(
prompts, prompts,
@@ -1738,7 +1738,7 @@ class StyleAlignedSDXLPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
@@ -689,7 +689,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} " f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of" f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input." " `pipeline.unet` or your `image` input."
) )
@@ -1028,7 +1028,7 @@ class StableDiffusionXL_AE_Pipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -1036,7 +1036,7 @@ class StableDiffusionXL_AE_Pipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -2050,7 +2050,7 @@ class StableDiffusionXL_AE_Pipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
@@ -1578,7 +1578,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
+1 -2
View File
@@ -288,8 +288,7 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
if timesteps[0] >= self.config.num_train_timesteps: if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError( raise ValueError(
f"`timesteps` must start before `self.config.train_timesteps`:" f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
f" {self.config.num_train_timesteps}."
) )
timesteps = np.array(timesteps, dtype=np.int64) timesteps = np.array(timesteps, dtype=np.int64)
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter # Set alpha parameter
if "lora_down" in kohya_key: if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha' alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict return kohya_ss_state_dict
@@ -901,7 +901,7 @@ def main(args):
unet_ = accelerator.unwrap_model(unet) unet_ = accelerator.unwrap_model(unet)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = { unet_state_dict = {
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
} }
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter # Set alpha parameter
if "lora_down" in kohya_key: if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha' alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
return kohya_ss_state_dict return kohya_ss_state_dict
+5 -3
View File
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total = 0 total = 0
pbar = tqdm(desc="downloading real regularization images", total=num_class_images) pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open( with (
f"{class_data_dir}/images.txt", "w" open(f"{class_data_dir}/caption.txt", "w") as f1,
) as f3: open(f"{class_data_dir}/urls.txt", "w") as f2,
open(f"{class_data_dir}/images.txt", "w") as f3,
):
while total < num_class_images: while total < num_class_images:
images = class_images[count] images = class_images[count]
count += 1 count += 1
@@ -731,18 +731,18 @@ def main(args):
if not class_images_dir.exists(): if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True) class_images_dir.mkdir(parents=True, exist_ok=True)
if args.real_prior: if args.real_prior:
assert ( assert (class_images_dir / "images").exists(), (
class_images_dir / "images" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
len(list((class_images_dir / "images").iterdir())) == args.num_class_images f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert (class_images_dir / "caption.txt").exists(), (
class_images_dir / "caption.txt" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
assert ( assert (class_images_dir / "images.txt").exists(), (
class_images_dir / "images.txt" f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" )
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt") concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt") concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
args.concepts_list[i] = concept args.concepts_list[i] = concept
+1 -1
View File
@@ -1014,7 +1014,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
+1 -1
View File
@@ -982,7 +982,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -1294,7 +1294,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1053,7 +1053,7 @@ def main(args):
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir) lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1064,7 +1064,7 @@ def main(args):
lora_state_dict = SanaPipeline.lora_state_dict(input_dir) lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1355,7 +1355,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -118,7 +118,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} # {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
@@ -1286,7 +1286,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
pipeline.load_lora_weights(args.output_dir) pipeline.load_lora_weights(args.output_dir)
assert ( assert pipeline.transformer.config.in_channels == initial_channels * 2, (
pipeline.transformer.config.in_channels == initial_channels * 2 f"{pipeline.transformer.config.in_channels=}"
), f"{pipeline.transformer.config.in_channels=}" )
pipeline.to(accelerator.device) pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
@@ -954,7 +954,7 @@ def main(args):
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir) lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
transformer_lora_state_dict = { transformer_lora_state_dict = {
f'{k.replace("transformer.", "")}': v f"{k.replace('transformer.', '')}": v
for k, v in lora_state_dict.items() for k, v in lora_state_dict.items()
if k.startswith("transformer.") and "lora" in k if k.startswith("transformer.") and "lora" in k
} }
+3 -3
View File
@@ -1081,9 +1081,9 @@ class AutoConfig:
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}" f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
) )
pretrained_model_name_or_paths[ pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
pretrained_model_name_or_paths.index(search_word) textual_inversion_path.model_path
] = textual_inversion_path.model_path )
self.load_textual_inversion( self.load_textual_inversion(
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors="pt", return_tensors="pt",
) )
tokens = batch_encoding["input_ids"] tokens = batch_encoding["input_ids"]
assert ( assert torch.count_nonzero(tokens - 49407) == 2, (
torch.count_nonzero(tokens - 49407) == 2 f"String '{string}' maps to more than a single token. Please use another string"
), f"String '{string}' maps to more than a single token. Please use another string" )
return tokens[0, 1] return tokens[0, 1]
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
assert ( assert H == self.img_size[0] and W == self.img_size[1], (
H == self.img_size[0] and W == self.img_size[1] f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." )
x = self.proj(x).flatten(2).permute(0, 2, 1) x = self.proj(x).flatten(2).permute(0, 2, 1)
return x return x
@@ -619,7 +619,7 @@ def main(args):
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0]) logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
@@ -803,21 +803,20 @@ def parse_args(input_args=None):
"--control_type", "--control_type",
type=str, type=str,
default="canny", default="canny",
help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."), help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
) )
parser.add_argument( parser.add_argument(
"--transformer_layers_per_block", "--transformer_layers_per_block",
type=str, type=str,
default=None, default=None,
help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."), help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
) )
parser.add_argument( parser.add_argument(
"--old_style_controlnet", "--old_style_controlnet",
action="store_true", action="store_true",
default=False, default=False,
help=( help=(
"Use the old style controlnet, which is a single transformer layer with" "Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
" a single head. Defaults to False."
), ),
) )
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
# create pipeline # create pipeline
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
@@ -683,7 +683,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False): def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.") logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
if is_final_validation: if is_final_validation:
if args.mixed_precision == "fp16": if args.mixed_precision == "fp16":
@@ -790,7 +790,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -783,7 +783,7 @@ def main(args):
lora_state_dict = FluxPipeline.lora_state_dict(input_dir) lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
File diff suppressed because one or more lines are too long
+18 -23
View File
@@ -26,8 +26,7 @@
"%load_ext autoreload\n", "%load_ext autoreload\n",
"%autoreload 2\n", "%autoreload 2\n",
"\n", "\n",
"import torch\n", "from diffusers import StableDiffusionGLIGENPipeline"
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
] ]
}, },
{ {
@@ -36,28 +35,25 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"import diffusers\n", "import diffusers\n",
"from diffusers import (\n", "from diffusers import (\n",
" AutoencoderKL,\n", " AutoencoderKL,\n",
" DDPMScheduler,\n", " DDPMScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n", " EulerDiscreteScheduler,\n",
" UNet2DConditionModel,\n",
")\n", ")\n",
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", "\n",
"\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n", "\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n", "\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n", "tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n", "noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n", "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n", "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n", "# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n", "# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n", "# )\n",
@@ -71,9 +67,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"unet = UNet2DConditionModel.from_pretrained(\n", "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
] ]
}, },
{ {
@@ -108,6 +102,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n",
"\n",
"\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n", "# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n", "# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n", "\n",
@@ -117,10 +114,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n", "# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n", "# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n", "\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n", "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n", "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n", "\n",
"boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n", "boxes = boxes / 512\n",
@@ -166,7 +161,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"diffusers.utils.make_image_grid(images, 4, len(images)//4)" "diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
] ]
}, },
{ {
@@ -179,7 +174,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "densecaption", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -197,5 +192,5 @@
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 4
} }
@@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
""" """
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix. Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
""" """
import argparse import argparse
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match # Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",") instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",") instance_prompt = args.instance_prompt.split(",")
assert all( assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)] "Instance data dir and prompt inputs are not of the same length."
), "Instance data dir and prompt inputs are not of the same length." )
if args.with_prior_preservation: if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",") class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None) negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts args.validation_negative_prompt = negative_validation_prompts
assert num_of_validation_prompts == len( assert num_of_validation_prompts == len(negative_validation_prompts), (
negative_validation_prompts "The length of negative prompts for validation is greater than the number of validation prompts."
), "The length of negative prompts for validation is greater than the number of validation prompts." )
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
@@ -830,9 +830,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator) index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -886,9 +886,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -663,8 +663,7 @@ class PromptDiffusionPipeline(
self.check_image(image, prompt, prompt_embeds) self.check_image(image, prompt, prompt_embeds)
else: else:
raise ValueError( raise ValueError(
f"You have passed a list of images of length {len(image_pair)}." f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
f"Make sure the list size equals to two."
) )
# Check `controlnet_conditioning_scale` # Check `controlnet_conditioning_scale`
@@ -173,7 +173,7 @@ class TrainSD:
if not dataloader_exception: if not dataloader_exception:
xm.wait_device_ops() xm.wait_device_ops()
total_time = time.time() - last_time total_time = time.time() - last_time
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}") print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
else: else:
print("dataloader exception happen, skip result") print("dataloader exception happen, skip result")
return return
@@ -622,7 +622,7 @@ def main(args):
num_devices_per_host = num_devices // num_hosts num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal(): if xm.is_master_ordinal():
print("***** Running training *****") print("***** Running training *****")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }") print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print( print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}" f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
) )
@@ -1057,7 +1057,7 @@ def main(args):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
@@ -1021,7 +1021,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
@@ -118,7 +118,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} # {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
<Gallery /> <Gallery />
@@ -1336,7 +1336,7 @@ def main(args):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -750,7 +750,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -767,7 +767,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None: if incompatible_keys is not None:
@@ -910,9 +910,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -965,12 +965,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad(): with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
index_no_updates orig_embeds_params[index_no_updates]
] = orig_embeds_params[index_no_updates] )
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
index_no_updates_2 orig_embeds_params_2[index_no_updates_2]
] = orig_embeds_params_2[index_no_updates_2] )
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
+3 -3
View File
@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path} --model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1 --checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir} --output_dir {tmpdir}
--seed=0 --seed=0
""".split() """.split()
@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
--model_config_name_or_path {vqmodel_config_path} --model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1 --checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir} --output_dir {tmpdir}
--use_ema --use_ema
--seed=0 --seed=0
@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
--discriminator_config_name_or_path {discriminator_config_path} --discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir} --output_dir {tmpdir}
--checkpointing_steps=2 --checkpointing_steps=2
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2 --checkpoints_total_limit=2
--seed=0 --seed=0
""".split() """.split()
+6 -6
View File
@@ -653,15 +653,15 @@ def main():
try: try:
# Gets the resolution of the timm transformation after centercrop # Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1] timm_centercrop_transform = timm_transform.transforms[1]
assert isinstance( assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
timm_centercrop_transform, transforms.CenterCrop f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." )
timm_model_resolution = timm_centercrop_transform.size[0] timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization # Gets final normalization
timm_model_normalization = timm_transform.transforms[-1] timm_model_normalization = timm_transform.transforms[-1]
assert isinstance( assert isinstance(timm_model_normalization, transforms.Normalize), (
timm_model_normalization, transforms.Normalize f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." )
except AssertionError as e: except AssertionError as e:
raise NotImplementedError(e) raise NotImplementedError(e)
# Enable flash attention if asked # Enable flash attention if asked
+1 -1
View File
@@ -3,7 +3,7 @@ line-length = 119
[tool.ruff.lint] [tool.ruff.lint]
# Never enforce `E501` (line length violations). # Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"] ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"] select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files. # Ignore import violations in all `__init__.py` files.
+1 -1
View File
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all() # assert (old_output == new_output).all()
print("skipping full vae equivalence check") print("skipping full vae equivalence check")
print(f"vae full diff { (old_output - new_output).float().abs().sum()}") print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
return new_vae return new_vae
+2 -2
View File
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1: if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0" new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer-1}.1" old_prefix = f"output_blocks.{current_layer - 1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D": elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1): for j in range(layers_per_block + 1):
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1: if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0" new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer-1}.2" old_prefix = f"output_blocks.{current_layer - 1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
@@ -261,9 +261,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0] model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path): if not os.path.isfile(args.model_path):
assert ( assert model_name == args.model_path, (
model_name == args.model_path f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" )
args.model_path = download(model_name) args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"] sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -290,9 +290,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items(): for key, value in renamed_state_dict.items():
assert ( assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" )
if key == "time_proj.weight": if key == "time_proj.weight":
value = value.squeeze() value = value.squeeze()
@@ -52,18 +52,18 @@ for i in range(3):
for j in range(2): for j in range(2):
# loop over resnets/attentions for downblocks # loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0: if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4): for j in range(4):
# loop over resnets/attentions for upblocks # loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0." sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2: if i < 2:
@@ -75,12 +75,12 @@ for i in range(3):
if i < 3: if i < 3:
# no downsample in down_blocks.3 # no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3 # no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv.")) unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2): for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}." hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}." sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -137,20 +137,20 @@ for i in range(4):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample." sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets # up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd # also, up blocks in hf are numbered in reverse from sd
for j in range(3): for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}." sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder # this part accounts for mid blocks in both the encoder and the decoder
for i in range(2): for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}." hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}." sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -47,36 +47,36 @@ for i in range(4):
for j in range(2): for j in range(2):
# loop over resnets/attentions for downblocks # loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3: if i < 3:
# no attention layers in down_blocks.3 # no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3): for j in range(3):
# loop over resnets/attentions for upblocks # loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0." sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0: if i > 0:
# no attention layers in up_blocks.0 # no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3: if i < 3:
# no downsample in down_blocks.3 # no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3 # no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0." hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2): for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}." hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}." sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -133,20 +133,20 @@ for i in range(4):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample." sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets # up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd # also, up blocks in hf are numbered in reverse from sd
for j in range(3): for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}." sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder # this part accounts for mid blocks in both the encoder and the decoder
for i in range(2): for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}." hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}." sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config( model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer" "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
) )
model_config[ model_config["use_style_cond_and_image_meta_size"] = (
"use_style_cond_and_image_meta_size" args.use_style_cond_and_image_meta_size
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False ) ### version <= v1.1: True; version >= v1.2: False
print(model_config) print(model_config)
for key in state_dict: for key in state_dict:
+4 -5
View File
@@ -13,15 +13,14 @@ def main(args):
state_dict = state_dict[args.load_key] state_dict = state_dict[args.load_key]
except KeyError: except KeyError:
raise KeyError( raise KeyError(
f"{args.load_key} not found in the checkpoint." f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
f"Please load from the following keys:{state_dict.keys()}"
) )
device = "cuda" device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer") model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
model_config[ model_config["use_style_cond_and_image_meta_size"] = (
"use_style_cond_and_image_meta_size" args.use_style_cond_and_image_meta_size
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False ) ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim # input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict: for key in state_dict:
+5 -5
View File
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}" diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2 idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}" self_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_prefix = f"{block_prefix}.{idx }" cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_index = 1 if not attention.add_self_attention else 2 cross_attention_index = 1 if not attention.add_self_attention else 2
idx = ( idx = (
n * attention_idx + cross_attention_index n * attention_idx + cross_attention_index
if block_type == "up" if block_type == "up"
else n * attention_idx + cross_attention_index + 1 else n * attention_idx + cross_attention_index + 1
) )
cross_attention_prefix = f"{block_prefix}.{idx }" cross_attention_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update( diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint( cross_attn_to_diffusers_checkpoint(
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"] block_out_channels = original_config["channels"]
assert ( assert len(set(original_config["depths"])) == 1, (
len(set(original_config["depths"])) == 1 "UNet2DConditionModel currently do not support blocks with different number of layers"
), "UNet2DConditionModel currently do not support blocks with different number of layers" )
layers_per_block = original_config["depths"][0] layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"] class_labels_dim = original_config["mapping_cond_dim"]
+60 -58
View File
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D) # Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3 for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.0.weight" f"blocks.0.{i + 1}.stack.0.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.0.bias" f"blocks.0.{i + 1}.stack.0.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.2.weight" f"blocks.0.{i + 1}.stack.2.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.2.bias" f"blocks.0.{i + 1}.stack.2.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.3.weight" f"blocks.0.{i + 1}.stack.3.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.3.bias" f"blocks.0.{i + 1}.stack.3.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.5.weight" f"blocks.0.{i + 1}.stack.5.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.5.bias" f"blocks.0.{i + 1}.stack.5.bias"
) )
# Convert up_blocks (MochiUpBlock3D) # Convert up_blocks (MochiUpBlock3D)
@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3): for block in range(3):
for i in range(down_block_layers[block]): for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.0.weight" f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.0.bias" f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.2.weight" f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.2.bias" f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.3.weight" f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.3.bias" f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.5.weight" f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.5.bias" f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.proj.weight" f"blocks.{block + 1}.proj.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.proj.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D) # Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3 for i in range(3): # layers_per_block[0] = 3
@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D) # Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3 for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.0.weight" f"layers.{i + 1}.stack.0.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.0.bias" f"layers.{i + 1}.stack.0.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.2.weight" f"layers.{i + 1}.stack.2.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.2.bias" f"layers.{i + 1}.stack.2.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.3.weight" f"layers.{i + 1}.stack.3.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.3.bias" f"layers.{i + 1}.stack.3.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.5.weight" f"layers.{i + 1}.stack.5.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.5.bias" f"layers.{i + 1}.stack.5.bias"
) )
# Convert down_blocks (MochiDownBlock3D) # Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3] down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3): for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.0.weight" f"layers.{block + 4}.layers.0.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.0.bias" f"layers.{block + 4}.layers.0.bias"
) )
for i in range(down_block_layers[block]): for i in range(down_block_layers[block]):
# Convert resnets # Convert resnets
new_state_dict[ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight") )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.0.bias" f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.2.weight" f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.2.bias" f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
) )
new_state_dict[
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.3.bias" f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.5.weight" f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.5.bias" f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
) )
# Convert attentions # Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight") qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0) q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight" f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias" f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
) )
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight" f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias" f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
) )
# Convert block_out (MochiMidBlock3D) # Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3 for i in range(3): # layers_per_block[-1] = 3
# Convert resnets # Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.0.weight" f"layers.{i + 7}.stack.0.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.0.bias" f"layers.{i + 7}.stack.0.bias"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.2.weight" f"layers.{i + 7}.stack.2.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.2.bias" f"layers.{i + 7}.stack.2.bias"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.3.weight" f"layers.{i + 7}.stack.3.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.3.bias" f"layers.{i + 7}.stack.3.bias"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.5.weight" f"layers.{i + 7}.stack.5.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.5.bias" f"layers.{i + 7}.stack.5.bias"
) )
# Convert attentions # Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight") qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0) q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.attn.out.weight" f"layers.{i + 7}.attn_block.attn.out.weight"
) )
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.attn.out.bias" f"layers.{i + 7}.attn_block.attn.out.bias"
) )
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.norm.weight" f"layers.{i + 7}.attn_block.norm.weight"
) )
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.norm.bias" f"layers.{i + 7}.attn_block.norm.bias"
) )
# Convert output layers # Convert output layers
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list # replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1) sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key): elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1)) projecton_layer = int(re.match(text_projection_pattern, key).group(1))
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list # replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1) sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key): elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1)) projecton_layer = int(re.match(text_projection_pattern, key).group(1))
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list # replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1) sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key): elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1)) projecton_layer = int(re.match(text_projection_pattern, key).group(1))
+9 -9
View File
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer # get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0]) idx = int(new_key.split("coder.layers.")[1].split(".")[0])
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}") new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
if "encoder" in new_key: if "encoder" in new_key:
for i in range(3): for i in range(3):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}") new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1") new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1") new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
else: else:
for i in range(2, 5): for i in range(2, 5):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}") new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1") new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1") new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta") new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha") new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_") new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1: if idx == num_autoencoder_layers + 1:
new_key = new_key.replace(f"block.{idx-1}", "snake1") new_key = new_key.replace(f"block.{idx - 1}", "snake1")
elif idx == num_autoencoder_layers + 2: elif idx == num_autoencoder_layers + 2:
new_key = new_key.replace(f"block.{idx-1}", "conv2") new_key = new_key.replace(f"block.{idx - 1}", "conv2")
else: else:
new_key = new_key new_key = new_key
+6 -6
View File
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor # TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[ new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] )
if len(attentions): if len(attentions):
paths = renew_attention_paths(attentions) paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
) )
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[ new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] )
output_block_list = {k: sorted(v) for k, v in output_block_list.items()} output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values(): if ["conv.bias", "conv.weight"] in output_block_list.values():
+12 -12
View File
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
def vqvae_model_from_original_config(original_config): def vqvae_model_from_original_config(original_config):
assert ( assert original_config["target"] in PORTED_VQVAES, (
original_config["target"] in PORTED_VQVAES f"{original_config['target']} has not yet been ported to diffusers."
), f"{original_config['target']} has not yet been ported to diffusers." )
original_config = original_config["params"] original_config = original_config["params"]
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
def transformer_model_from_original_config( def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config original_diffusion_config, original_transformer_config, original_content_embedding_config
): ):
assert ( assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
original_diffusion_config["target"] in PORTED_DIFFUSIONS f"{original_diffusion_config['target']} has not yet been ported to diffusers."
), f"{original_diffusion_config['target']} has not yet been ported to diffusers." )
assert ( assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
original_transformer_config["target"] in PORTED_TRANSFORMERS f"{original_transformer_config['target']} has not yet been ported to diffusers."
), f"{original_transformer_config['target']} has not yet been ported to diffusers." )
assert ( assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers." )
original_diffusion_config = original_diffusion_config["params"] original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"] original_transformer_config = original_transformer_config["params"]
+1 -1
View File
@@ -122,7 +122,7 @@ _deps = [
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
"python>=3.8.0", "python>=3.8.0",
"ruff==0.1.5", "ruff==0.9.10",
"safetensors>=0.3.1", "safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19", "GitPython<3.1.19",
+1 -1
View File
@@ -29,7 +29,7 @@ deps = {
"pytest-timeout": "pytest-timeout", "pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist", "pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0", "python": "python>=3.8.0",
"ruff": "ruff==0.1.5", "ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1", "safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19", "GitPython": "GitPython<3.1.19",
+1 -2
View File
@@ -295,8 +295,7 @@ class IPAdapterMixin:
): ):
if len(scale_configs) != len(attn_processor.scale): if len(scale_configs) != len(attn_processor.scale):
raise ValueError( raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to " f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
f"{len(attn_processor.scale)} IP-Adapter."
) )
elif len(scale_configs) == 1: elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale) scale_configs = scale_configs * len(attn_processor.scale)
+33 -33
View File
@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present. # Store DoRA scale if present.
if dora_present_in_unet: if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
# Handle text encoder LoRAs. # Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
) )
if lora_name.startswith(("lora_te_", "lora_te1_")): if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
elif lora_name.startswith("lora_te2_"): elif lora_name.startswith("lora_te2_"):
te2_state_dict[ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
# Store alpha if present. # Store alpha if present.
if lora_name_alpha in state_dict: if lora_name_alpha in state_dict:
@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
for lora_key in ["lora_A", "lora_B"]: for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in ## time_text_embed.timestep_embedder <- time_in
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") )
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") )
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") )
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") )
## time_text_embed.text_embedder <- vector_in ## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
# guidance # guidance
has_guidance = any("guidance" in k for k in original_state_dict) has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance: if has_guidance:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") )
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") )
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") )
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") )
# context_embedder # context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+1 -1
View File
@@ -26,6 +26,7 @@ _import_structure = {}
if is_torch_available(): if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -41,7 +42,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["cache_utils"] = ["CacheMixin"] _import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+1 -1
View File
@@ -205,7 +205,7 @@ def load_state_dict(
) from e ) from e
except (UnicodeDecodeError, ValueError): except (UnicodeDecodeError, ValueError):
raise OSError( raise OSError(
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
) )
@@ -211,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
def _init_vectorized_inputs(self, norm_type): def _init_vectorized_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert ( assert self.config.num_vector_embeds is not None, (
self.config.num_vector_embeds is not None "Transformer2DModel over discrete input must provide num_embed"
), "Transformer2DModel over discrete input must provide num_embed" )
self.height = self.config.sample_size self.height = self.config.sample_size
self.width = self.config.sample_size self.width = self.config.sample_size
@@ -791,7 +791,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
if transcription is None: if transcription is None:
if self.text_encoder_2.config.model_type == "vits": if self.text_encoder_2.config.model_type == "vits":
raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription") raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
elif transcription is not None and ( elif transcription is not None and (
not isinstance(transcription, str) and not isinstance(transcription, list) not isinstance(transcription, str) and not isinstance(transcription, list)
): ):
@@ -657,7 +657,7 @@ class StableDiffusionControlNetInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -665,7 +665,7 @@ class StableDiffusionControlNetInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple # `prompt` needs more sophisticated handling when there are multiple
# conditionings. # conditionings.
@@ -1130,7 +1130,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input." " `pipeline.transformer` or your `mask_image` or `image` input."
) )
@@ -507,7 +507,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -515,7 +515,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512: if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -574,7 +574,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -582,7 +582,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512: if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+3 -3
View File
@@ -341,9 +341,9 @@ class AnimateDiffFreeNoiseMixin:
start_tensor = negative_prompt_embeds[i].unsqueeze(0) start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0) end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
negative_prompt_interpolation_embeds[ negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
start_frame : end_frame + 1 self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor) )
prompt_embeds = prompt_interpolation_embeds prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds negative_prompt_embeds = negative_prompt_interpolation_embeds
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
""" """
_load_connected_pipes = True _load_connected_pipes = True
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq" model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
_exclude_from_cpu_offload = ["prior_prior"] _exclude_from_cpu_offload = ["prior_prior"]
def __init__( def __init__(
@@ -579,7 +579,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
@@ -95,13 +95,13 @@ class OmniGenMultiModalProcessor:
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids)) unique_image_ids = sorted(set(image_ids))
assert unique_image_ids == list( assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
range(1, len(unique_image_ids) + 1) f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" )
# total images must be the same as the number of image tags # total images must be the same as the number of image tags
assert ( assert len(unique_image_ids) == len(input_images), (
len(unique_image_ids) == len(input_images) f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" )
input_images = [input_images[x - 1] for x in image_ids] input_images = [input_images[x - 1] for x in image_ids]
@@ -604,7 +604,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -612,7 +612,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple # `prompt` needs more sophisticated handling when there are multiple
# conditionings. # conditionings.
@@ -1340,7 +1340,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
@@ -683,7 +683,7 @@ class StableDiffusionPAGInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -691,7 +691,7 @@ class StableDiffusionPAGInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -1191,7 +1191,7 @@ class StableDiffusionPAGInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
@@ -737,7 +737,7 @@ class StableDiffusionXLPAGInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
@@ -745,7 +745,7 @@ class StableDiffusionXLPAGInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None: if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError( raise ValueError(
@@ -1509,7 +1509,7 @@ class StableDiffusionXLPAGInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
@@ -575,7 +575,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
@@ -323,9 +323,7 @@ def maybe_raise_or_warn(
model_cls = unwrapped_sub_model.__class__ model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj): if not issubclass(model_cls, expected_class_obj):
raise ValueError( raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
)
else: else:
logger.warning( logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+6 -6
View File
@@ -983,9 +983,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
fields = torch.cat(fields, dim=1) fields = torch.cat(fields, dim=1)
fields = fields.float() fields = fields.float()
assert ( assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
len(fields.shape) == 3 and fields.shape[-1] == 1 f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" )
fields = fields.reshape(1, *([grid_size] * 3)) fields = fields.reshape(1, *([grid_size] * 3))
@@ -1039,9 +1039,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
textures = textures.float() textures = textures.float()
# 3.3 augument the mesh with texture data # 3.3 augument the mesh with texture data
assert len(textures.shape) == 3 and textures.shape[-1] == len( assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
texture_channels f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" )
for m, texture in zip(raw_meshes, textures): for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)] texture = texture[: len(m.verts)]
@@ -584,7 +584,7 @@ class StableAudioPipeline(DiffusionPipeline):
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError( raise ValueError(
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
) )
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
@@ -335,7 +335,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
@@ -475,7 +475,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
"Incorrect configuration settings! The config of `pipeline.unet` expects" "Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )

Some files were not shown because too many files have changed in this diff Show More