Compare commits
6 Commits
chroma
...
ruff-update
| Author | SHA1 | Date | |
|---|---|---|---|
| b365801c57 | |||
| 644147a198 | |||
| c852f239f2 | |||
| be861e236f | |||
| 2d744f0707 | |||
| 41c7e72d44 |
@@ -839,9 +839,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -1605,7 +1605,7 @@ def main(args):
|
||||
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -200,7 +200,8 @@ Special VAE used for training: {vae_path}.
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
lora,
|
||||
"template:sd-lora" "stable-diffusion",
|
||||
"template:sd-lora",
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
@@ -724,9 +725,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -746,9 +747,9 @@ class TokenEmbeddingsHandler:
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[
|
||||
f"original_embeddings_{idx}"
|
||||
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -1322,7 +1323,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -890,9 +890,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -912,9 +912,9 @@ class TokenEmbeddingsHandler:
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[
|
||||
f"original_embeddings_{idx}"
|
||||
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -1647,7 +1647,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -720,7 +720,7 @@ def main(args):
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num training steps = {args.max_train_steps}")
|
||||
logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
|
||||
|
||||
@@ -1138,7 +1138,7 @@ def main(args):
|
||||
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1159,7 +1159,7 @@ def main(args):
|
||||
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1103,7 +1103,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `default_mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -686,7 +686,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -362,7 +362,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
|
||||
if verbose:
|
||||
logger.info(
|
||||
f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
|
||||
f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1184,7 +1184,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
|
||||
if verbose:
|
||||
logger.info(
|
||||
f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
|
||||
f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
|
||||
)
|
||||
|
||||
finally:
|
||||
|
||||
@@ -701,7 +701,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
|
||||
raise ValueError("`max_tile_size` cannot be None.")
|
||||
elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
|
||||
raise ValueError(
|
||||
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}."
|
||||
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}."
|
||||
)
|
||||
if tile_gaussian_sigma is None:
|
||||
raise ValueError("`tile_gaussian_sigma` cannot be None.")
|
||||
|
||||
@@ -488,7 +488,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -496,7 +496,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@@ -907,12 +907,12 @@ def create_controller(
|
||||
|
||||
# reweight
|
||||
if edit_type == "reweight":
|
||||
assert (
|
||||
equalizer_words is not None and equalizer_strengths is not None
|
||||
), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
|
||||
assert len(equalizer_words) == len(
|
||||
equalizer_strengths
|
||||
), "equalizer_words and equalizer_strengths must be of same length."
|
||||
assert equalizer_words is not None and equalizer_strengths is not None, (
|
||||
"To use reweight edit, please specify equalizer_words and equalizer_strengths."
|
||||
)
|
||||
assert len(equalizer_words) == len(equalizer_strengths), (
|
||||
"equalizer_words and equalizer_strengths must be of same length."
|
||||
)
|
||||
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
|
||||
return AttentionReweight(
|
||||
prompts,
|
||||
|
||||
@@ -1738,7 +1738,7 @@ class StyleAlignedSDXLPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -689,7 +689,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -1028,7 +1028,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -1036,7 +1036,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -2050,7 +2050,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -1578,7 +1578,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -288,8 +288,7 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if timesteps[0] >= self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`timesteps` must start before `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps}."
|
||||
f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
|
||||
)
|
||||
|
||||
timesteps = np.array(timesteps, dtype=np.int64)
|
||||
|
||||
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
|
||||
|
||||
# Set alpha parameter
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
alpha_key = f"{kohya_key.split('.')[0]}.alpha"
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
|
||||
|
||||
return kohya_ss_state_dict
|
||||
|
||||
@@ -901,7 +901,7 @@ def main(args):
|
||||
unet_ = accelerator.unwrap_model(unet)
|
||||
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
|
||||
unet_state_dict = {
|
||||
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
|
||||
|
||||
# Set alpha parameter
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
alpha_key = f"{kohya_key.split('.')[0]}.alpha"
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
|
||||
|
||||
return kohya_ss_state_dict
|
||||
|
||||
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
|
||||
total = 0
|
||||
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
|
||||
|
||||
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
|
||||
f"{class_data_dir}/images.txt", "w"
|
||||
) as f3:
|
||||
with (
|
||||
open(f"{class_data_dir}/caption.txt", "w") as f1,
|
||||
open(f"{class_data_dir}/urls.txt", "w") as f2,
|
||||
open(f"{class_data_dir}/images.txt", "w") as f3,
|
||||
):
|
||||
while total < num_class_images:
|
||||
images = class_images[count]
|
||||
count += 1
|
||||
|
||||
@@ -731,18 +731,18 @@ def main(args):
|
||||
if not class_images_dir.exists():
|
||||
class_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
if args.real_prior:
|
||||
assert (
|
||||
class_images_dir / "images"
|
||||
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (
|
||||
len(list((class_images_dir / "images").iterdir())) == args.num_class_images
|
||||
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (
|
||||
class_images_dir / "caption.txt"
|
||||
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (
|
||||
class_images_dir / "images.txt"
|
||||
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (class_images_dir / "images").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (class_images_dir / "caption.txt").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (class_images_dir / "images.txt").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
|
||||
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
|
||||
args.concepts_list[i] = concept
|
||||
|
||||
@@ -1014,7 +1014,7 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
|
||||
@@ -982,7 +982,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
|
||||
@@ -1294,7 +1294,7 @@ def main(args):
|
||||
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1053,7 +1053,7 @@ def main(args):
|
||||
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1064,7 +1064,7 @@ def main(args):
|
||||
lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1355,7 +1355,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -118,7 +118,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -1286,7 +1286,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
assert (
|
||||
pipeline.transformer.config.in_channels == initial_channels * 2
|
||||
), f"{pipeline.transformer.config.in_channels=}"
|
||||
assert pipeline.transformer.config.in_channels == initial_channels * 2, (
|
||||
f"{pipeline.transformer.config.in_channels=}"
|
||||
)
|
||||
|
||||
pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
@@ -954,7 +954,7 @@ def main(args):
|
||||
|
||||
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
|
||||
transformer_lora_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v
|
||||
f"{k.replace('transformer.', '')}": v
|
||||
for k, v in lora_state_dict.items()
|
||||
if k.startswith("transformer.") and "lora" in k
|
||||
}
|
||||
|
||||
@@ -1081,9 +1081,9 @@ class AutoConfig:
|
||||
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
|
||||
)
|
||||
|
||||
pretrained_model_name_or_paths[
|
||||
pretrained_model_name_or_paths.index(search_word)
|
||||
] = textual_inversion_path.model_path
|
||||
pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
|
||||
textual_inversion_path.model_path
|
||||
)
|
||||
|
||||
self.load_textual_inversion(
|
||||
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
|
||||
|
||||
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"]
|
||||
assert (
|
||||
torch.count_nonzero(tokens - 49407) == 2
|
||||
), f"String '{string}' maps to more than a single token. Please use another string"
|
||||
assert torch.count_nonzero(tokens - 49407) == 2, (
|
||||
f"String '{string}' maps to more than a single token. Please use another string"
|
||||
)
|
||||
return tokens[0, 1]
|
||||
|
||||
|
||||
|
||||
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
assert H == self.img_size[0] and W == self.img_size[1], (
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
)
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
@@ -619,7 +619,7 @@ def main(args):
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
|
||||
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
@@ -803,21 +803,20 @@ def parse_args(input_args=None):
|
||||
"--control_type",
|
||||
type=str,
|
||||
default="canny",
|
||||
help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
|
||||
help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transformer_layers_per_block",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
|
||||
help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_style_controlnet",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Use the old style controlnet, which is a single transformer layer with"
|
||||
" a single head. Defaults to False."
|
||||
"Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
|
||||
|
||||
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
|
||||
|
||||
|
||||
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
|
||||
if is_final_validation:
|
||||
if args.mixed_precision == "fp16":
|
||||
|
||||
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
|
||||
|
||||
|
||||
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
|
||||
if is_final_validation:
|
||||
if args.mixed_precision == "fp16":
|
||||
@@ -683,7 +683,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
|
||||
|
||||
|
||||
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
|
||||
if is_final_validation:
|
||||
if args.mixed_precision == "fp16":
|
||||
@@ -790,7 +790,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
+1
-1
@@ -783,7 +783,7 @@ def main(args):
|
||||
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -26,8 +26,7 @@
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
|
||||
"from diffusers import StableDiffusionGLIGENPipeline"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -36,28 +35,25 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from transformers import CLIPTextModel, CLIPTokenizer\n",
|
||||
"\n",
|
||||
"import diffusers\n",
|
||||
"from diffusers import (\n",
|
||||
" AutoencoderKL,\n",
|
||||
" DDPMScheduler,\n",
|
||||
" UNet2DConditionModel,\n",
|
||||
" UniPCMultistepScheduler,\n",
|
||||
" EulerDiscreteScheduler,\n",
|
||||
" UNet2DConditionModel,\n",
|
||||
")\n",
|
||||
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
|
||||
"\n",
|
||||
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
|
||||
"pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
|
||||
"\n",
|
||||
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
|
||||
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
|
||||
"text_encoder = CLIPTextModel.from_pretrained(\n",
|
||||
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
|
||||
")\n",
|
||||
"vae = AutoencoderKL.from_pretrained(\n",
|
||||
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
|
||||
")\n",
|
||||
"text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
|
||||
"vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
|
||||
"# unet = UNet2DConditionModel.from_pretrained(\n",
|
||||
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
|
||||
"# )\n",
|
||||
@@ -71,9 +67,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"unet = UNet2DConditionModel.from_pretrained(\n",
|
||||
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
|
||||
")"
|
||||
"unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -108,6 +102,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
|
||||
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
|
||||
"\n",
|
||||
@@ -117,10 +114,8 @@
|
||||
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
|
||||
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
|
||||
"\n",
|
||||
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
|
||||
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
|
||||
"gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
|
||||
"\n",
|
||||
"boxes = np.array([x[1] for x in gen_boxes])\n",
|
||||
"boxes = boxes / 512\n",
|
||||
@@ -166,7 +161,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"diffusers.utils.make_image_grid(images, 4, len(images)//4)"
|
||||
"diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -179,7 +174,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "densecaption",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -197,5 +192,5 @@
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
|
||||
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
|
||||
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
+6
-6
@@ -763,9 +763,9 @@ def main(args):
|
||||
# Parse instance and class inputs, and double check that lengths match
|
||||
instance_data_dir = args.instance_data_dir.split(",")
|
||||
instance_prompt = args.instance_prompt.split(",")
|
||||
assert all(
|
||||
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
|
||||
), "Instance data dir and prompt inputs are not of the same length."
|
||||
assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
|
||||
"Instance data dir and prompt inputs are not of the same length."
|
||||
)
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_data_dir = args.class_data_dir.split(",")
|
||||
@@ -788,9 +788,9 @@ def main(args):
|
||||
negative_validation_prompts.append(None)
|
||||
args.validation_negative_prompt = negative_validation_prompts
|
||||
|
||||
assert num_of_validation_prompts == len(
|
||||
negative_validation_prompts
|
||||
), "The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
assert num_of_validation_prompts == len(negative_validation_prompts), (
|
||||
"The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
)
|
||||
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
|
||||
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
|
||||
|
||||
|
||||
@@ -830,9 +830,9 @@ def main():
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = get_mask(tokenizer, accelerator)
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -886,9 +886,9 @@ def main():
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -663,8 +663,7 @@ class PromptDiffusionPipeline(
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"You have passed a list of images of length {len(image_pair)}."
|
||||
f"Make sure the list size equals to two."
|
||||
f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
|
||||
)
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
|
||||
+2
-2
@@ -173,7 +173,7 @@ class TrainSD:
|
||||
if not dataloader_exception:
|
||||
xm.wait_device_ops()
|
||||
total_time = time.time() - last_time
|
||||
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
|
||||
print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
|
||||
else:
|
||||
print("dataloader exception happen, skip result")
|
||||
return
|
||||
@@ -622,7 +622,7 @@ def main(args):
|
||||
num_devices_per_host = num_devices // num_hosts
|
||||
if xm.is_master_ordinal():
|
||||
print("***** Running training *****")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
|
||||
print(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
|
||||
)
|
||||
|
||||
+1
-1
@@ -1057,7 +1057,7 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
|
||||
+1
-1
@@ -1021,7 +1021,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
|
||||
+2
-2
@@ -118,7 +118,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -1336,7 +1336,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
+1
-1
@@ -750,7 +750,7 @@ def main(args):
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -765,7 +765,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -767,7 +767,7 @@ def main(args):
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -910,9 +910,9 @@ def main():
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -965,12 +965,12 @@ def main():
|
||||
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
|
||||
index_no_updates_2
|
||||
] = orig_embeds_params_2[index_no_updates_2]
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
|
||||
orig_embeds_params_2[index_no_updates_2]
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--model_config_name_or_path {vqmodel_config_path}
|
||||
--discriminator_config_name_or_path {discriminator_config_path}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
""".split()
|
||||
@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--model_config_name_or_path {vqmodel_config_path}
|
||||
--discriminator_config_name_or_path {discriminator_config_path}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
|
||||
--output_dir {tmpdir}
|
||||
--use_ema
|
||||
--seed=0
|
||||
@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--discriminator_config_name_or_path {discriminator_config_path}
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
|
||||
--checkpoints_total_limit=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
@@ -653,15 +653,15 @@ def main():
|
||||
try:
|
||||
# Gets the resolution of the timm transformation after centercrop
|
||||
timm_centercrop_transform = timm_transform.transforms[1]
|
||||
assert isinstance(
|
||||
timm_centercrop_transform, transforms.CenterCrop
|
||||
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
|
||||
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
)
|
||||
timm_model_resolution = timm_centercrop_transform.size[0]
|
||||
# Gets final normalization
|
||||
timm_model_normalization = timm_transform.transforms[-1]
|
||||
assert isinstance(
|
||||
timm_model_normalization, transforms.Normalize
|
||||
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
assert isinstance(timm_model_normalization, transforms.Normalize), (
|
||||
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise NotImplementedError(e)
|
||||
# Enable flash attention if asked
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ line-length = 119
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Never enforce `E501` (line length violations).
|
||||
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
||||
ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
|
||||
select = ["C", "E", "F", "I", "W"]
|
||||
|
||||
# Ignore import violations in all `__init__.py` files.
|
||||
|
||||
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
|
||||
|
||||
# assert (old_output == new_output).all()
|
||||
print("skipping full vae equivalence check")
|
||||
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
|
||||
print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
|
||||
|
||||
return new_vae
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.1"
|
||||
old_prefix = f"output_blocks.{current_layer - 1}.1"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
elif layer_type == "AttnUpBlock2D":
|
||||
for j in range(layers_per_block + 1):
|
||||
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.2"
|
||||
old_prefix = f"output_blocks.{current_layer - 1}.2"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
||||
|
||||
@@ -261,9 +261,9 @@ def main(args):
|
||||
|
||||
model_name = args.model_path.split("/")[-1].split(".")[0]
|
||||
if not os.path.isfile(args.model_path):
|
||||
assert (
|
||||
model_name == args.model_path
|
||||
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
assert model_name == args.model_path, (
|
||||
f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
)
|
||||
args.model_path = download(model_name)
|
||||
|
||||
sample_rate = MODELS_MAP[model_name]["sample_rate"]
|
||||
@@ -290,9 +290,9 @@ def main(args):
|
||||
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
|
||||
|
||||
for key, value in renamed_state_dict.items():
|
||||
assert (
|
||||
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
|
||||
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
|
||||
f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
)
|
||||
if key == "time_proj.weight":
|
||||
value = value.squeeze()
|
||||
|
||||
|
||||
@@ -52,18 +52,18 @@ for i in range(3):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(4):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i < 2:
|
||||
@@ -75,12 +75,12 @@ for i in range(3):
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
|
||||
|
||||
@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
@@ -137,20 +137,20 @@ for i in range(4):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
|
||||
@@ -47,36 +47,36 @@ for i in range(4):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
@@ -133,20 +133,20 @@ for i in range(4):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ def main(args):
|
||||
model_config = HunyuanDiT2DControlNetModel.load_config(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
|
||||
)
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
print(model_config)
|
||||
|
||||
for key in state_dict:
|
||||
|
||||
@@ -13,15 +13,14 @@ def main(args):
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
f"Please load from the following keys:{state_dict.keys()}"
|
||||
f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
|
||||
# input_size -> sample_size, text_dim -> cross_attention_dim
|
||||
for key in state_dict:
|
||||
|
||||
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
|
||||
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
|
||||
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
|
||||
self_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx }"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_index = 1 if not attention.add_self_attention else 2
|
||||
idx = (
|
||||
n * attention_idx + cross_attention_index
|
||||
if block_type == "up"
|
||||
else n * attention_idx + cross_attention_index + 1
|
||||
)
|
||||
cross_attention_prefix = f"{block_prefix}.{idx }"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
cross_attn_to_diffusers_checkpoint(
|
||||
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
|
||||
|
||||
block_out_channels = original_config["channels"]
|
||||
|
||||
assert (
|
||||
len(set(original_config["depths"])) == 1
|
||||
), "UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
assert len(set(original_config["depths"])) == 1, (
|
||||
"UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
)
|
||||
layers_per_block = original_config["depths"][0]
|
||||
|
||||
class_labels_dim = original_config["mapping_cond_dim"]
|
||||
|
||||
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.0.weight"
|
||||
f"blocks.0.{i + 1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.0.bias"
|
||||
f"blocks.0.{i + 1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.2.weight"
|
||||
f"blocks.0.{i + 1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.2.bias"
|
||||
f"blocks.0.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.3.weight"
|
||||
f"blocks.0.{i + 1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.3.bias"
|
||||
f"blocks.0.{i + 1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.5.weight"
|
||||
f"blocks.0.{i + 1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.5.bias"
|
||||
f"blocks.0.{i + 1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert up_blocks (MochiUpBlock3D)
|
||||
@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
for block in range(3):
|
||||
for i in range(down_block_layers[block]):
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.proj.weight"
|
||||
f"blocks.{block + 1}.proj.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.proj.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.0.weight"
|
||||
f"layers.{i + 1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.0.bias"
|
||||
f"layers.{i + 1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.2.weight"
|
||||
f"layers.{i + 1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.2.bias"
|
||||
f"layers.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.3.weight"
|
||||
f"layers.{i + 1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.3.bias"
|
||||
f"layers.{i + 1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.5.weight"
|
||||
f"layers.{i + 1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.5.bias"
|
||||
f"layers.{i + 1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert down_blocks (MochiDownBlock3D)
|
||||
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
|
||||
for block in range(3):
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.0.weight"
|
||||
f"layers.{block + 4}.layers.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.0.bias"
|
||||
f"layers.{block + 4}.layers.0.bias"
|
||||
)
|
||||
|
||||
for i in range(down_block_layers[block]):
|
||||
# Convert resnets
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.0.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.3.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
# Convert resnets
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.0.weight"
|
||||
f"layers.{i + 7}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.0.bias"
|
||||
f"layers.{i + 7}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.2.weight"
|
||||
f"layers.{i + 7}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.2.bias"
|
||||
f"layers.{i + 7}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.3.weight"
|
||||
f"layers.{i + 7}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.3.bias"
|
||||
f"layers.{i + 7}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.5.weight"
|
||||
f"layers.{i + 7}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.5.bias"
|
||||
f"layers.{i + 7}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.attn.out.weight"
|
||||
f"layers.{i + 7}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.attn.out.bias"
|
||||
f"layers.{i + 7}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.norm.weight"
|
||||
f"layers.{i + 7}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.norm.bias"
|
||||
f"layers.{i + 7}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert output layers
|
||||
|
||||
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
|
||||
# get idx of the layer
|
||||
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
|
||||
|
||||
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
|
||||
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
|
||||
|
||||
if "encoder" in new_key:
|
||||
for i in range(3):
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
|
||||
else:
|
||||
for i in range(2, 5):
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
|
||||
|
||||
new_key = new_key.replace("layers.0.beta", "snake1.beta")
|
||||
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
|
||||
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
|
||||
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
|
||||
|
||||
if idx == num_autoencoder_layers + 1:
|
||||
new_key = new_key.replace(f"block.{idx-1}", "snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}", "snake1")
|
||||
elif idx == num_autoencoder_layers + 2:
|
||||
new_key = new_key.replace(f"block.{idx-1}", "conv2")
|
||||
new_key = new_key.replace(f"block.{idx - 1}", "conv2")
|
||||
|
||||
else:
|
||||
new_key = new_key
|
||||
|
||||
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
|
||||
|
||||
# TODO resnet time_mixer.mix_factor
|
||||
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[
|
||||
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
|
||||
)
|
||||
|
||||
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[
|
||||
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
|
||||
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
|
||||
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert (
|
||||
original_config["target"] in PORTED_VQVAES
|
||||
), f"{original_config['target']} has not yet been ported to diffusers."
|
||||
assert original_config["target"] in PORTED_VQVAES, (
|
||||
f"{original_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
|
||||
original_config = original_config["params"]
|
||||
|
||||
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert (
|
||||
original_diffusion_config["target"] in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config["target"] in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
|
||||
f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
|
||||
f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
|
||||
f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
|
||||
original_diffusion_config = original_diffusion_config["params"]
|
||||
original_transformer_config = original_transformer_config["params"]
|
||||
|
||||
@@ -122,7 +122,7 @@ _deps = [
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"python>=3.8.0",
|
||||
"ruff==0.1.5",
|
||||
"ruff==0.9.10",
|
||||
"safetensors>=0.3.1",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
"GitPython<3.1.19",
|
||||
|
||||
@@ -29,7 +29,7 @@ deps = {
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"python": "python>=3.8.0",
|
||||
"ruff": "ruff==0.1.5",
|
||||
"ruff": "ruff==0.9.10",
|
||||
"safetensors": "safetensors>=0.3.1",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
|
||||
@@ -295,8 +295,7 @@ class IPAdapterMixin:
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
|
||||
@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_unet:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
|
||||
# Handle text encoder LoRAs.
|
||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
|
||||
# Store alpha if present.
|
||||
if lora_name_alpha in state_dict:
|
||||
@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
)
|
||||
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
)
|
||||
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
)
|
||||
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
)
|
||||
|
||||
## time_text_embed.text_embedder <- vector_in
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
|
||||
@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
# guidance
|
||||
has_guidance = any("guidance" in k for k in original_state_dict)
|
||||
if has_guidance:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
)
|
||||
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
)
|
||||
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
)
|
||||
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
)
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
|
||||
|
||||
@@ -26,6 +26,7 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
@@ -41,7 +42,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
_import_structure["cache_utils"] = ["CacheMixin"]
|
||||
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
|
||||
@@ -205,7 +205,7 @@ def load_state_dict(
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -211,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
|
||||
def _init_vectorized_inputs(self, norm_type):
|
||||
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert (
|
||||
self.config.num_vector_embeds is not None
|
||||
), "Transformer2DModel over discrete input must provide num_embed"
|
||||
assert self.config.num_vector_embeds is not None, (
|
||||
"Transformer2DModel over discrete input must provide num_embed"
|
||||
)
|
||||
|
||||
self.height = self.config.sample_size
|
||||
self.width = self.config.sample_size
|
||||
|
||||
@@ -791,7 +791,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
|
||||
if transcription is None:
|
||||
if self.text_encoder_2.config.model_type == "vits":
|
||||
raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
|
||||
raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
|
||||
elif transcription is not None and (
|
||||
not isinstance(transcription, str) and not isinstance(transcription, list)
|
||||
):
|
||||
|
||||
@@ -657,7 +657,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -665,7 +665,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
|
||||
@@ -1130,7 +1130,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
||||
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.transformer` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -507,7 +507,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -515,7 +515,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@@ -574,7 +574,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -582,7 +582,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@@ -341,9 +341,9 @@ class AnimateDiffFreeNoiseMixin:
|
||||
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
|
||||
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
|
||||
|
||||
negative_prompt_interpolation_embeds[
|
||||
start_frame : end_frame + 1
|
||||
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
|
||||
self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_interpolation_embeds
|
||||
negative_prompt_embeds = negative_prompt_interpolation_embeds
|
||||
|
||||
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
"""
|
||||
|
||||
_load_connected_pipes = True
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -579,7 +579,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -95,13 +95,13 @@ class OmniGenMultiModalProcessor:
|
||||
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
||||
|
||||
unique_image_ids = sorted(set(image_ids))
|
||||
assert unique_image_ids == list(
|
||||
range(1, len(unique_image_ids) + 1)
|
||||
), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
||||
assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
|
||||
f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
||||
)
|
||||
# total images must be the same as the number of image tags
|
||||
assert (
|
||||
len(unique_image_ids) == len(input_images)
|
||||
), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
||||
assert len(unique_image_ids) == len(input_images), (
|
||||
f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
||||
)
|
||||
|
||||
input_images = [input_images[x - 1] for x in image_ids]
|
||||
|
||||
|
||||
@@ -604,7 +604,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -612,7 +612,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
@@ -1340,7 +1340,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -683,7 +683,7 @@ class StableDiffusionPAGInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -691,7 +691,7 @@ class StableDiffusionPAGInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -1191,7 +1191,7 @@ class StableDiffusionPAGInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -737,7 +737,7 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -745,7 +745,7 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -1509,7 +1509,7 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -575,7 +575,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -323,9 +323,7 @@ def maybe_raise_or_warn(
|
||||
model_cls = unwrapped_sub_model.__class__
|
||||
|
||||
if not issubclass(model_cls, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
|
||||
)
|
||||
raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
|
||||
@@ -983,9 +983,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
fields = torch.cat(fields, dim=1)
|
||||
fields = fields.float()
|
||||
|
||||
assert (
|
||||
len(fields.shape) == 3 and fields.shape[-1] == 1
|
||||
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
|
||||
assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
|
||||
f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
|
||||
)
|
||||
|
||||
fields = fields.reshape(1, *([grid_size] * 3))
|
||||
|
||||
@@ -1039,9 +1039,9 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
textures = textures.float()
|
||||
|
||||
# 3.3 augument the mesh with texture data
|
||||
assert len(textures.shape) == 3 and textures.shape[-1] == len(
|
||||
texture_channels
|
||||
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
|
||||
assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
|
||||
f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
|
||||
)
|
||||
|
||||
for m, texture in zip(raw_meshes, textures):
|
||||
texture = texture[: len(m.verts)]
|
||||
|
||||
@@ -584,7 +584,7 @@ class StableAudioPipeline(DiffusionPipeline):
|
||||
|
||||
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
|
||||
raise ValueError(
|
||||
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
|
||||
f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
|
||||
)
|
||||
|
||||
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
|
||||
|
||||
@@ -335,7 +335,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -475,7 +475,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user