Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b56112db6e | |||
| f50de75b69 | |||
| 579bb5f418 |
@@ -28,9 +28,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio\
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
onnxruntime \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
|
||||
@@ -265,8 +265,6 @@
|
||||
sections:
|
||||
- local: api/models/overview
|
||||
title: Overview
|
||||
- local: api/models/auto_model
|
||||
title: AutoModel
|
||||
- sections:
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoModel
|
||||
|
||||
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel, AutoPipelineForText2Image
|
||||
|
||||
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
|
||||
```
|
||||
|
||||
|
||||
## AutoModel
|
||||
|
||||
[[autodoc]] AutoModel
|
||||
- all
|
||||
- from_pretrained
|
||||
@@ -839,9 +839,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -1605,7 +1605,7 @@ def main(args):
|
||||
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -200,8 +200,7 @@ Special VAE used for training: {vae_path}.
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
lora,
|
||||
"template:sd-lora",
|
||||
"stable-diffusion",
|
||||
"template:sd-lora" "stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
@@ -725,9 +724,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -747,9 +746,9 @@ class TokenEmbeddingsHandler:
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
self.embeddings_settings[
|
||||
f"original_embeddings_{idx}"
|
||||
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -1323,7 +1322,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -890,9 +890,9 @@ class TokenEmbeddingsHandler:
|
||||
idx = 0
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
|
||||
assert all(isinstance(tok, str) for tok in inserting_toks), (
|
||||
"All elements in inserting_toks should be strings."
|
||||
)
|
||||
assert all(
|
||||
isinstance(tok, str) for tok in inserting_toks
|
||||
), "All elements in inserting_toks should be strings."
|
||||
|
||||
self.inserting_toks = inserting_toks
|
||||
special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
@@ -912,9 +912,9 @@ class TokenEmbeddingsHandler:
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
self.embeddings_settings[
|
||||
f"original_embeddings_{idx}"
|
||||
] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -1647,7 +1647,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -720,7 +720,7 @@ def main(args):
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num training steps = {args.max_train_steps}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
|
||||
|
||||
@@ -1138,7 +1138,7 @@ def main(args):
|
||||
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1159,7 +1159,7 @@ def main(args):
|
||||
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1103,7 +1103,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `default_mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -686,7 +686,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -362,7 +362,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
|
||||
if verbose:
|
||||
logger.info(
|
||||
f"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
|
||||
f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1184,7 +1184,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
|
||||
if verbose:
|
||||
logger.info(
|
||||
f"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
|
||||
f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
|
||||
)
|
||||
|
||||
finally:
|
||||
|
||||
@@ -701,7 +701,7 @@ class StableDiffusionXLControlNetTileSRPipeline(
|
||||
raise ValueError("`max_tile_size` cannot be None.")
|
||||
elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
|
||||
raise ValueError(
|
||||
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}."
|
||||
f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}."
|
||||
)
|
||||
if tile_gaussian_sigma is None:
|
||||
raise ValueError("`tile_gaussian_sigma` cannot be None.")
|
||||
|
||||
@@ -488,7 +488,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -496,7 +496,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@@ -907,12 +907,12 @@ def create_controller(
|
||||
|
||||
# reweight
|
||||
if edit_type == "reweight":
|
||||
assert equalizer_words is not None and equalizer_strengths is not None, (
|
||||
"To use reweight edit, please specify equalizer_words and equalizer_strengths."
|
||||
)
|
||||
assert len(equalizer_words) == len(equalizer_strengths), (
|
||||
"equalizer_words and equalizer_strengths must be of same length."
|
||||
)
|
||||
assert (
|
||||
equalizer_words is not None and equalizer_strengths is not None
|
||||
), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
|
||||
assert len(equalizer_words) == len(
|
||||
equalizer_strengths
|
||||
), "equalizer_words and equalizer_strengths must be of same length."
|
||||
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
|
||||
return AttentionReweight(
|
||||
prompts,
|
||||
|
||||
@@ -1738,7 +1738,7 @@ class StyleAlignedSDXLPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -689,7 +689,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents + num_channels_image}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -1028,7 +1028,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -1036,7 +1036,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -2050,7 +2050,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -1578,7 +1578,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
|
||||
@@ -288,7 +288,8 @@ class UFOGenScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if timesteps[0] >= self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
|
||||
f"`timesteps` must start before `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps}."
|
||||
)
|
||||
|
||||
timesteps = np.array(timesteps, dtype=np.int64)
|
||||
|
||||
@@ -89,7 +89,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
|
||||
|
||||
# Set alpha parameter
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f"{kohya_key.split('.')[0]}.alpha"
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
|
||||
|
||||
return kohya_ss_state_dict
|
||||
|
||||
@@ -901,7 +901,7 @@ def main(args):
|
||||
unet_ = accelerator.unwrap_model(unet)
|
||||
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
|
||||
unet_state_dict = {
|
||||
f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
|
||||
|
||||
# Set alpha parameter
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f"{kohya_key.split('.')[0]}.alpha"
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
|
||||
|
||||
return kohya_ss_state_dict
|
||||
|
||||
@@ -50,11 +50,9 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
|
||||
total = 0
|
||||
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
|
||||
|
||||
with (
|
||||
open(f"{class_data_dir}/caption.txt", "w") as f1,
|
||||
open(f"{class_data_dir}/urls.txt", "w") as f2,
|
||||
open(f"{class_data_dir}/images.txt", "w") as f3,
|
||||
):
|
||||
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
|
||||
f"{class_data_dir}/images.txt", "w"
|
||||
) as f3:
|
||||
while total < num_class_images:
|
||||
images = class_images[count]
|
||||
count += 1
|
||||
|
||||
@@ -731,18 +731,18 @@ def main(args):
|
||||
if not class_images_dir.exists():
|
||||
class_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
if args.real_prior:
|
||||
assert (class_images_dir / "images").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert len(list((class_images_dir / "images").iterdir())) == args.num_class_images, (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (class_images_dir / "caption.txt").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (class_images_dir / "images.txt").exists(), (
|
||||
f'Please run: python retrieve.py --class_prompt "{concept["class_prompt"]}" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'
|
||||
)
|
||||
assert (
|
||||
class_images_dir / "images"
|
||||
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (
|
||||
len(list((class_images_dir / "images").iterdir())) == args.num_class_images
|
||||
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (
|
||||
class_images_dir / "caption.txt"
|
||||
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
assert (
|
||||
class_images_dir / "images.txt"
|
||||
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
|
||||
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
|
||||
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
|
||||
args.concepts_list[i] = concept
|
||||
|
||||
@@ -1014,7 +1014,7 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
|
||||
@@ -982,7 +982,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
|
||||
@@ -1294,7 +1294,7 @@ def main(args):
|
||||
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1053,7 +1053,7 @@ def main(args):
|
||||
lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1064,7 +1064,7 @@ def main(args):
|
||||
lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -1355,7 +1355,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -118,7 +118,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -1286,7 +1286,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
assert pipeline.transformer.config.in_channels == initial_channels * 2, (
|
||||
f"{pipeline.transformer.config.in_channels=}"
|
||||
)
|
||||
assert (
|
||||
pipeline.transformer.config.in_channels == initial_channels * 2
|
||||
), f"{pipeline.transformer.config.in_channels=}"
|
||||
|
||||
pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
@@ -954,7 +954,7 @@ def main(args):
|
||||
|
||||
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
|
||||
transformer_lora_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v
|
||||
f'{k.replace("transformer.", "")}': v
|
||||
for k, v in lora_state_dict.items()
|
||||
if k.startswith("transformer.") and "lora" in k
|
||||
}
|
||||
|
||||
@@ -1081,9 +1081,9 @@ class AutoConfig:
|
||||
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
|
||||
)
|
||||
|
||||
pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (
|
||||
textual_inversion_path.model_path
|
||||
)
|
||||
pretrained_model_name_or_paths[
|
||||
pretrained_model_name_or_paths.index(search_word)
|
||||
] = textual_inversion_path.model_path
|
||||
|
||||
self.load_textual_inversion(
|
||||
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
|
||||
|
||||
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"]
|
||||
assert torch.count_nonzero(tokens - 49407) == 2, (
|
||||
f"String '{string}' maps to more than a single token. Please use another string"
|
||||
)
|
||||
assert (
|
||||
torch.count_nonzero(tokens - 49407) == 2
|
||||
), f"String '{string}' maps to more than a single token. Please use another string"
|
||||
return tokens[0, 1]
|
||||
|
||||
|
||||
|
||||
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], (
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
)
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
@@ -619,7 +619,7 @@ def main(args):
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB", ranks=[0])
|
||||
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
@@ -803,20 +803,21 @@ def parse_args(input_args=None):
|
||||
"--control_type",
|
||||
type=str,
|
||||
default="canny",
|
||||
help=("The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."),
|
||||
help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transformer_layers_per_block",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."),
|
||||
help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_style_controlnet",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
|
||||
"Use the old style controlnet, which is a single transformer layer with"
|
||||
" a single head. Defaults to False."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
|
||||
|
||||
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
|
||||
|
||||
|
||||
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
|
||||
if is_final_validation:
|
||||
if args.mixed_precision == "fp16":
|
||||
|
||||
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
|
||||
|
||||
|
||||
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
|
||||
if is_final_validation:
|
||||
if args.mixed_precision == "fp16":
|
||||
@@ -683,7 +683,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
|
||||
|
||||
|
||||
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
||||
logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
|
||||
|
||||
if is_final_validation:
|
||||
if args.mixed_precision == "fp16":
|
||||
@@ -790,7 +790,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
+1
-1
@@ -783,7 +783,7 @@ def main(args):
|
||||
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -26,7 +26,8 @@
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"\n",
|
||||
"from diffusers import StableDiffusionGLIGENPipeline"
|
||||
"import torch\n",
|
||||
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -35,25 +36,28 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import CLIPTextModel, CLIPTokenizer\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import diffusers\n",
|
||||
"from diffusers import (\n",
|
||||
" AutoencoderKL,\n",
|
||||
" DDPMScheduler,\n",
|
||||
" EulerDiscreteScheduler,\n",
|
||||
" UNet2DConditionModel,\n",
|
||||
" UniPCMultistepScheduler,\n",
|
||||
" EulerDiscreteScheduler,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
|
||||
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
|
||||
"\n",
|
||||
"pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
|
||||
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
|
||||
"\n",
|
||||
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
|
||||
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
|
||||
"text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
|
||||
"vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
|
||||
"text_encoder = CLIPTextModel.from_pretrained(\n",
|
||||
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
|
||||
")\n",
|
||||
"vae = AutoencoderKL.from_pretrained(\n",
|
||||
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
|
||||
")\n",
|
||||
"# unet = UNet2DConditionModel.from_pretrained(\n",
|
||||
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
|
||||
"# )\n",
|
||||
@@ -67,7 +71,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
|
||||
"unet = UNet2DConditionModel.from_pretrained(\n",
|
||||
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -102,9 +108,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
|
||||
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
|
||||
"\n",
|
||||
@@ -114,8 +117,10 @@
|
||||
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
|
||||
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
|
||||
"\n",
|
||||
"prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
|
||||
"gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
|
||||
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
|
||||
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"boxes = np.array([x[1] for x in gen_boxes])\n",
|
||||
"boxes = boxes / 512\n",
|
||||
@@ -161,7 +166,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
|
||||
"diffusers.utils.make_image_grid(images, 4, len(images)//4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -174,7 +179,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "densecaption",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -192,5 +197,5 @@
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
|
||||
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
|
||||
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
+6
-6
@@ -763,9 +763,9 @@ def main(args):
|
||||
# Parse instance and class inputs, and double check that lengths match
|
||||
instance_data_dir = args.instance_data_dir.split(",")
|
||||
instance_prompt = args.instance_prompt.split(",")
|
||||
assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
|
||||
"Instance data dir and prompt inputs are not of the same length."
|
||||
)
|
||||
assert all(
|
||||
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
|
||||
), "Instance data dir and prompt inputs are not of the same length."
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_data_dir = args.class_data_dir.split(",")
|
||||
@@ -788,9 +788,9 @@ def main(args):
|
||||
negative_validation_prompts.append(None)
|
||||
args.validation_negative_prompt = negative_validation_prompts
|
||||
|
||||
assert num_of_validation_prompts == len(negative_validation_prompts), (
|
||||
"The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
)
|
||||
assert num_of_validation_prompts == len(
|
||||
negative_validation_prompts
|
||||
), "The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
|
||||
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
|
||||
|
||||
|
||||
@@ -830,9 +830,9 @@ def main():
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = get_mask(tokenizer, accelerator)
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -886,9 +886,9 @@ def main():
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -663,7 +663,8 @@ class PromptDiffusionPipeline(
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
|
||||
f"You have passed a list of images of length {len(image_pair)}."
|
||||
f"Make sure the list size equals to two."
|
||||
)
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
|
||||
+2
-2
@@ -173,7 +173,7 @@ class TrainSD:
|
||||
if not dataloader_exception:
|
||||
xm.wait_device_ops()
|
||||
total_time = time.time() - last_time
|
||||
print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
|
||||
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
|
||||
else:
|
||||
print("dataloader exception happen, skip result")
|
||||
return
|
||||
@@ -622,7 +622,7 @@ def main(args):
|
||||
num_devices_per_host = num_devices // num_hosts
|
||||
if xm.is_master_ordinal():
|
||||
print("***** Running training *****")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
|
||||
print(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
|
||||
)
|
||||
|
||||
+1
-1
@@ -1057,7 +1057,7 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
|
||||
raise ValueError(
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
|
||||
f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
|
||||
+1
-1
@@ -1021,7 +1021,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
|
||||
|
||||
+2
-2
@@ -118,7 +118,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -1336,7 +1336,7 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
+1
-1
@@ -750,7 +750,7 @@ def main(args):
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -765,7 +765,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -767,7 +767,7 @@ def main(args):
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -910,9 +910,9 @@ def main():
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -965,12 +965,12 @@ def main():
|
||||
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
|
||||
orig_embeds_params_2[index_no_updates_2]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
|
||||
index_no_updates_2
|
||||
] = orig_embeds_params_2[index_no_updates_2]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -177,7 +177,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--model_config_name_or_path {vqmodel_config_path}
|
||||
--discriminator_config_name_or_path {discriminator_config_path}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
""".split()
|
||||
@@ -262,7 +262,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--model_config_name_or_path {vqmodel_config_path}
|
||||
--discriminator_config_name_or_path {discriminator_config_path}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
|
||||
--output_dir {tmpdir}
|
||||
--use_ema
|
||||
--seed=0
|
||||
@@ -377,7 +377,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--discriminator_config_name_or_path {discriminator_config_path}
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
|
||||
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
|
||||
--checkpoints_total_limit=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
@@ -653,15 +653,15 @@ def main():
|
||||
try:
|
||||
# Gets the resolution of the timm transformation after centercrop
|
||||
timm_centercrop_transform = timm_transform.transforms[1]
|
||||
assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
|
||||
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
)
|
||||
assert isinstance(
|
||||
timm_centercrop_transform, transforms.CenterCrop
|
||||
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
timm_model_resolution = timm_centercrop_transform.size[0]
|
||||
# Gets final normalization
|
||||
timm_model_normalization = timm_transform.transforms[-1]
|
||||
assert isinstance(timm_model_normalization, transforms.Normalize), (
|
||||
f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
)
|
||||
assert isinstance(
|
||||
timm_model_normalization, transforms.Normalize
|
||||
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
|
||||
except AssertionError as e:
|
||||
raise NotImplementedError(e)
|
||||
# Enable flash attention if asked
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ line-length = 119
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Never enforce `E501` (line length violations).
|
||||
ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
|
||||
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
||||
select = ["C", "E", "F", "I", "W"]
|
||||
|
||||
# Ignore import violations in all `__init__.py` files.
|
||||
|
||||
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
|
||||
|
||||
# assert (old_output == new_output).all()
|
||||
print("skipping full vae equivalence check")
|
||||
print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
|
||||
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
|
||||
|
||||
return new_vae
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer - 1}.1"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.1"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
elif layer_type == "AttnUpBlock2D":
|
||||
for j in range(layers_per_block + 1):
|
||||
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer - 1}.2"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.2"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
||||
|
||||
@@ -261,9 +261,9 @@ def main(args):
|
||||
|
||||
model_name = args.model_path.split("/")[-1].split(".")[0]
|
||||
if not os.path.isfile(args.model_path):
|
||||
assert model_name == args.model_path, (
|
||||
f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
)
|
||||
assert (
|
||||
model_name == args.model_path
|
||||
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
args.model_path = download(model_name)
|
||||
|
||||
sample_rate = MODELS_MAP[model_name]["sample_rate"]
|
||||
@@ -290,9 +290,9 @@ def main(args):
|
||||
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
|
||||
|
||||
for key, value in renamed_state_dict.items():
|
||||
assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
|
||||
f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
)
|
||||
assert (
|
||||
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
|
||||
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
if key == "time_proj.weight":
|
||||
value = value.squeeze()
|
||||
|
||||
|
||||
@@ -52,18 +52,18 @@ for i in range(3):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(4):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i < 2:
|
||||
@@ -75,12 +75,12 @@ for i in range(3):
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
|
||||
|
||||
@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
@@ -137,20 +137,20 @@ for i in range(4):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
|
||||
@@ -47,36 +47,36 @@ for i in range(4):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
@@ -133,20 +133,20 @@ for i in range(4):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ def main(args):
|
||||
model_config = HunyuanDiT2DControlNetModel.load_config(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
|
||||
)
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
print(model_config)
|
||||
|
||||
for key in state_dict:
|
||||
|
||||
@@ -13,14 +13,15 @@ def main(args):
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
f"Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
|
||||
# input_size -> sample_size, text_dim -> cross_attention_dim
|
||||
for key in state_dict:
|
||||
|
||||
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
|
||||
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
|
||||
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
|
||||
self_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx }"
|
||||
cross_attention_index = 1 if not attention.add_self_attention else 2
|
||||
idx = (
|
||||
n * attention_idx + cross_attention_index
|
||||
if block_type == "up"
|
||||
else n * attention_idx + cross_attention_index + 1
|
||||
)
|
||||
cross_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx }"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
cross_attn_to_diffusers_checkpoint(
|
||||
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
|
||||
|
||||
block_out_channels = original_config["channels"]
|
||||
|
||||
assert len(set(original_config["depths"])) == 1, (
|
||||
"UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
)
|
||||
assert (
|
||||
len(set(original_config["depths"])) == 1
|
||||
), "UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
layers_per_block = original_config["depths"][0]
|
||||
|
||||
class_labels_dim = original_config["mapping_cond_dim"]
|
||||
|
||||
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.0.weight"
|
||||
f"blocks.0.{i+1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.0.bias"
|
||||
f"blocks.0.{i+1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.2.weight"
|
||||
f"blocks.0.{i+1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.2.bias"
|
||||
f"blocks.0.{i+1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.3.weight"
|
||||
f"blocks.0.{i+1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.3.bias"
|
||||
f"blocks.0.{i+1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.5.weight"
|
||||
f"blocks.0.{i+1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i + 1}.stack.5.bias"
|
||||
f"blocks.0.{i+1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert up_blocks (MochiUpBlock3D)
|
||||
@@ -197,35 +197,33 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
for block in range(3):
|
||||
for i in range(down_block_layers[block]):
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.proj.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.proj.bias"
|
||||
f"blocks.{block+1}.proj.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
@@ -269,133 +267,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.0.weight"
|
||||
f"layers.{i+1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.0.bias"
|
||||
f"layers.{i+1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.2.weight"
|
||||
f"layers.{i+1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.2.bias"
|
||||
f"layers.{i+1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.3.weight"
|
||||
f"layers.{i+1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.3.bias"
|
||||
f"layers.{i+1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.5.weight"
|
||||
f"layers.{i+1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 1}.stack.5.bias"
|
||||
f"layers.{i+1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert down_blocks (MochiDownBlock3D)
|
||||
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
|
||||
for block in range(3):
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.0.weight"
|
||||
f"layers.{block+4}.layers.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.0.bias"
|
||||
f"layers.{block+4}.layers.0.bias"
|
||||
)
|
||||
|
||||
for i in range(down_block_layers[block]):
|
||||
# Convert resnets
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
|
||||
f"layers.{block+4}.layers.{i+1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
|
||||
f"layers.{block+4}.layers.{i+1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
# Convert resnets
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.0.weight"
|
||||
f"layers.{i+7}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.0.bias"
|
||||
f"layers.{i+7}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.2.weight"
|
||||
f"layers.{i+7}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.2.bias"
|
||||
f"layers.{i+7}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.3.weight"
|
||||
f"layers.{i+7}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.3.bias"
|
||||
f"layers.{i+7}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.5.weight"
|
||||
f"layers.{i+7}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.stack.5.bias"
|
||||
f"layers.{i+7}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.attn_block.attn.out.weight"
|
||||
f"layers.{i+7}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.attn_block.attn.out.bias"
|
||||
f"layers.{i+7}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.attn_block.norm.weight"
|
||||
f"layers.{i+7}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i + 7}.attn_block.norm.bias"
|
||||
f"layers.{i+7}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert output layers
|
||||
|
||||
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
|
||||
# get idx of the layer
|
||||
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
|
||||
|
||||
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
|
||||
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
|
||||
|
||||
if "encoder" in new_key:
|
||||
for i in range(3):
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
|
||||
else:
|
||||
for i in range(2, 5):
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
|
||||
|
||||
new_key = new_key.replace("layers.0.beta", "snake1.beta")
|
||||
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
|
||||
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
|
||||
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
|
||||
|
||||
if idx == num_autoencoder_layers + 1:
|
||||
new_key = new_key.replace(f"block.{idx - 1}", "snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}", "snake1")
|
||||
elif idx == num_autoencoder_layers + 2:
|
||||
new_key = new_key.replace(f"block.{idx - 1}", "conv2")
|
||||
new_key = new_key.replace(f"block.{idx-1}", "conv2")
|
||||
|
||||
else:
|
||||
new_key = new_key
|
||||
|
||||
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
|
||||
|
||||
# TODO resnet time_mixer.mix_factor
|
||||
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
new_checkpoint[
|
||||
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
|
||||
)
|
||||
|
||||
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
new_checkpoint[
|
||||
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
|
||||
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
|
||||
@@ -53,12 +53,7 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [
|
||||
key
|
||||
for key in down_blocks[i]
|
||||
if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
|
||||
]
|
||||
attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||
@@ -72,10 +67,6 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
paths = renew_vae_attention_paths(attentions)
|
||||
meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
@@ -94,11 +85,8 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key
|
||||
for key in up_blocks[block_id]
|
||||
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
@@ -112,10 +100,6 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
paths = renew_vae_attention_paths(attentions)
|
||||
meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
|
||||
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert original_config["target"] in PORTED_VQVAES, (
|
||||
f"{original_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert (
|
||||
original_config["target"] in PORTED_VQVAES
|
||||
), f"{original_config['target']} has not yet been ported to diffusers."
|
||||
|
||||
original_config = original_config["params"]
|
||||
|
||||
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
|
||||
f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
|
||||
f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
|
||||
f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert (
|
||||
original_diffusion_config["target"] in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config["target"] in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
|
||||
original_diffusion_config = original_diffusion_config["params"]
|
||||
original_transformer_config = original_transformer_config["params"]
|
||||
|
||||
@@ -122,7 +122,7 @@ _deps = [
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"python>=3.8.0",
|
||||
"ruff==0.9.10",
|
||||
"ruff==0.1.5",
|
||||
"safetensors>=0.3.1",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
"GitPython<3.1.19",
|
||||
@@ -142,7 +142,6 @@ _deps = [
|
||||
"urllib3<=2.0.0",
|
||||
"black",
|
||||
"phonemizer",
|
||||
"opencv-python",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
|
||||
@@ -14,7 +14,6 @@ from .utils import (
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
@@ -156,7 +155,6 @@ else:
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"CacheMixin",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
@@ -199,7 +197,6 @@ else:
|
||||
"T2IAdapter",
|
||||
"T5FilmDecoder",
|
||||
"Transformer2DModel",
|
||||
"TransformerTemporalModel",
|
||||
"UNet1DModel",
|
||||
"UNet2DConditionModel",
|
||||
"UNet2DModel",
|
||||
@@ -353,6 +350,7 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
@@ -518,19 +516,6 @@ else:
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_opencv_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_opencv_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["ConsisIDPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -746,7 +731,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
CacheMixin,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
@@ -788,7 +772,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
@@ -922,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
@@ -1100,15 +1084,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_opencv_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import ConsisIDPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -29,7 +29,7 @@ deps = {
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"python": "python>=3.8.0",
|
||||
"ruff": "ruff==0.9.10",
|
||||
"ruff": "ruff==0.1.5",
|
||||
"safetensors": "safetensors>=0.3.1",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
@@ -49,5 +49,4 @@ deps = {
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
"black": "black",
|
||||
"phonemizer": "phonemizer",
|
||||
"opencv-python": "opencv-python",
|
||||
}
|
||||
|
||||
@@ -295,7 +295,8 @@ class IPAdapterMixin:
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
|
||||
@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_unet:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
# Handle text encoder LoRAs.
|
||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
te_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
|
||||
state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
)
|
||||
te2_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
# Store alpha if present.
|
||||
if lora_name_alpha in state_dict:
|
||||
@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
|
||||
## time_text_embed.text_embedder <- vector_in
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
|
||||
@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
# guidance
|
||||
has_guidance = any("guidance" in k for k in original_state_dict)
|
||||
if has_guidance:
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
|
||||
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
|
||||
original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
)
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
|
||||
@@ -1608,64 +1608,3 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_musubi_wan_lora_to_diffusers(state_dict):
|
||||
# https://github.com/kohya-ss/musubi-tuner
|
||||
converted_state_dict = {}
|
||||
original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
|
||||
def get_alpha_scales(down_weight, key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = original_state_dict.pop(key + ".alpha").item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for i in range(num_blocks):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
|
||||
down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
|
||||
up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -42,7 +42,6 @@ from .lora_conversion_utils import (
|
||||
_convert_bfl_flux_control_lora_to_diffusers,
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
@@ -4795,8 +4794,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
if any(k.startswith("diffusion_model.") for k in state_dict):
|
||||
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
|
||||
elif any(k.startswith("lora_unet_") for k in state_dict):
|
||||
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
|
||||
@@ -177,7 +177,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
@@ -639,9 +638,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
|
||||
model_type = "ltx-video-0.9.5"
|
||||
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
@@ -2406,41 +2403,13 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_095_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
}
|
||||
|
||||
if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
|
||||
@@ -26,7 +26,6 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["auto_model"] = ["AutoModel"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
@@ -104,7 +103,6 @@ if is_flax_available():
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .auto_model import AutoModel
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderDC,
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
class AutoModel(ConfigMixin):
    """Loader that resolves the concrete model class from a checkpoint's `config.json`.

    The class is never instantiated itself: [`~AutoModel.from_pretrained`] reads the `_class_name` entry of the
    config, looks that name up on the top-level `diffusers` module and delegates to that class's
    `from_pretrained`.
    """

    config_name = "config.json"

    def __init__(self, *args, **kwargs):
        # Direct construction is not supported. The message names only `from_pretrained`, which is the only
        # constructor this class defines.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_or_path)` method."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
        r"""
        Instantiate a pretrained PyTorch model from a pretrained model configuration.

        The hub-related arguments (`cache_dir`, `force_download`, `proxies`, `token`, `local_files_only`, `revision`
        and `subfolder`) are used to load the config and are then forwarded, together with all remaining keyword
        arguments, to the `from_pretrained` method of the resolved model class.

        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
        train the model, set it back in training mode with `model.train()`.

        Parameters:
            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`~ModelMixin.save_pretrained`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            from_flax (`bool`, *optional*, defaults to `False`):
                Load the model weights from a Flax checkpoint save file.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device. Defaults to `None`, meaning that the model will be loaded on CPU.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if `device_map` contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            disable_mmap (`bool`, *optional*, defaults to `False`):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        Raises:
            ValueError: If the class named by the config's `_class_name` entry is not exposed by `diffusers`.

        Example:

        ```py
        from diffusers import AutoModel

        unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
        ```

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```bash
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        # Hub arguments needed to fetch the config, with the defaults used when the caller omits them.
        hub_kwarg_defaults = {
            "cache_dir": None,
            "force_download": False,
            "proxies": None,
            "token": None,
            "local_files_only": False,
            "revision": None,
            "subfolder": None,
        }
        load_config_kwargs = {name: kwargs.pop(name, default) for name, default in hub_kwarg_defaults.items()}

        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        # Resolve the concrete class by name on the top-level `diffusers` package.
        library = importlib.import_module("diffusers")

        model_cls = getattr(library, orig_class_name, None)
        if model_cls is None:
            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

        # The resolved class needs the same hub arguments to locate the weights.
        kwargs = {**load_config_kwargs, **kwargs}
        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
@@ -298,6 +298,15 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
@@ -311,15 +320,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
|
||||
@@ -205,7 +205,7 @@ def load_state_dict(
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -211,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
|
||||
def _init_vectorized_inputs(self, norm_type):
|
||||
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert self.config.num_vector_embeds is not None, (
|
||||
"Transformer2DModel over discrete input must provide num_embed"
|
||||
)
|
||||
assert (
|
||||
self.config.num_vector_embeds is not None
|
||||
), "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = self.config.sample_size
|
||||
self.width = self.config.sample_size
|
||||
|
||||
@@ -10,7 +10,6 @@ from ..utils import (
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
@@ -156,6 +155,7 @@ else:
|
||||
]
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
@@ -414,18 +414,6 @@ else:
|
||||
"KolorsImg2ImgPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import (
|
||||
dummy_torch_and_transformers_and_opencv_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
|
||||
else:
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -524,6 +512,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
|
||||
from .consisid import ConsisIDPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
@@ -772,14 +761,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KolorsPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_opencv_objects import *
|
||||
else:
|
||||
from .consisid import ConsisIDPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from transformers import (
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Model,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
SpeechT5HifiGan,
|
||||
@@ -196,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
text_encoder: ClapModel,
|
||||
text_encoder_2: Union[T5EncoderModel, VitsModel],
|
||||
projection_model: AudioLDM2ProjectionModel,
|
||||
language_model: GPT2LMHeadModel,
|
||||
language_model: GPT2Model,
|
||||
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
|
||||
tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
|
||||
feature_extractor: ClapFeatureExtractor,
|
||||
@@ -259,10 +259,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
device_type = torch_device.type
|
||||
device_str = device_type
|
||||
if gpu_id or torch_device.index:
|
||||
device_str = f"{device_str}:{gpu_id or torch_device.index}"
|
||||
device = torch.device(device_str)
|
||||
device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
@@ -319,9 +316,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
|
||||
|
||||
# forward pass to get next hidden states
|
||||
output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
|
||||
output = self.language_model(**model_inputs, return_dict=True)
|
||||
|
||||
next_hidden_states = output.hidden_states[-1]
|
||||
next_hidden_states = output.last_hidden_state
|
||||
|
||||
# Update the model input
|
||||
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
|
||||
@@ -791,7 +788,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
|
||||
if transcription is None:
|
||||
if self.text_encoder_2.config.model_type == "vits":
|
||||
raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
|
||||
raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
|
||||
elif transcription is not None and (
|
||||
not isinstance(transcription, str) and not isinstance(transcription, list)
|
||||
):
|
||||
|
||||
@@ -5,7 +5,6 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_opencv_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -16,12 +15,12 @@ _import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_opencv_available()):
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_consisid"] = ["ConsisIDPipeline"]
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
@@ -28,16 +29,12 @@ from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import CogVideoXDPMScheduler
|
||||
from ...utils import is_opencv_available, logging, replace_example_docstring
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import ConsisIDPipelineOutput
|
||||
|
||||
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
||||
@@ -657,7 +657,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -665,7 +665,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
|
||||
@@ -1130,7 +1130,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
||||
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
||||
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.transformer` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
|
||||
@@ -507,7 +507,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -515,7 +515,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@@ -574,7 +574,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -582,7 +582,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@@ -341,9 +341,9 @@ class AnimateDiffFreeNoiseMixin:
|
||||
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
|
||||
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
|
||||
|
||||
negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
|
||||
self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
)
|
||||
negative_prompt_interpolation_embeds[
|
||||
start_frame : end_frame + 1
|
||||
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
|
||||
prompt_embeds = prompt_interpolation_embeds
|
||||
negative_prompt_embeds = negative_prompt_interpolation_embeds
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user