Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 57afaff270 | |||
| e7b2032082 | |||
| 2d5e9c2e39 | |||
| d68635f950 | |||
| 371f765908 | |||
| 75aee39eac | |||
| 215e6804d3 | |||
| 9254d1f39a | |||
| e1bdcc7af3 | |||
| 84905ca728 | |||
| 6f336650c3 | |||
| 06a042cd0e | |||
| 8772496586 | |||
| 35fd84be27 | |||
| f2756253e6 | |||
| 0071478d9e | |||
| 7c8cab313e | |||
| ca9ed5e8d1 |
@@ -34,11 +34,6 @@ jobs:
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: LoRA
|
||||
framework: lora
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_lora
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
@@ -94,14 +89,6 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests
|
||||
if: ${{ matrix.config.framework == 'lora' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not Dependency" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
|
||||
@@ -26,9 +26,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
|
||||
@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
|
||||
@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
|
||||
@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
|
||||
@@ -42,6 +42,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Marigold monocular depth prediction pipeline.
|
||||
|
||||
@@ -49,6 +49,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -195,7 +196,7 @@ def import_model_class_from_model_name_or_path(
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
||||
def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None):
|
||||
img_str = ""
|
||||
if image_logs is not None:
|
||||
img_str = "You can find some example images below.\n"
|
||||
@@ -209,27 +210,25 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
|
||||
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- t2iadapter
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
model_description = f"""
|
||||
# t2iadapter-{repo_id}
|
||||
|
||||
These are t2iadapter weights trained on {base_model} with new type of conditioning.
|
||||
{img_str}
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "t2iadapter"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
|
||||
@@ -45,6 +45,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -75,21 +76,7 @@ def save_model_card(
|
||||
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
||||
img_str += "\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {args.pretrained_model_name_or_path}
|
||||
datasets:
|
||||
- {args.dataset_name}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
model_description = f"""
|
||||
# Text-to-image finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
||||
@@ -132,10 +119,21 @@ These are the key hyperparameters used during training:
|
||||
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
||||
"""
|
||||
|
||||
model_card += wandb_info
|
||||
model_description += wandb_info
|
||||
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
|
||||
|
||||
@@ -45,6 +45,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDif
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import cast_training_params, compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -61,26 +62,31 @@ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str,
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
model_description = f"""
|
||||
# LoRA text2image fine-tuning - {repo_id}
|
||||
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
||||
{img_str}
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = [
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
@@ -58,6 +58,7 @@ from diffusers.utils import (
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -70,33 +71,20 @@ logger = get_logger(__name__)
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model=str,
|
||||
dataset_name=str,
|
||||
train_text_encoder=False,
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
images: list = None,
|
||||
base_model: str = None,
|
||||
dataset_name: str = None,
|
||||
train_text_encoder: bool = False,
|
||||
repo_folder: str = None,
|
||||
vae_path: str = None,
|
||||
):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
dataset: {dataset_name}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
model_description = f"""
|
||||
# LoRA text2image fine-tuning - {repo_id}
|
||||
|
||||
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
||||
@@ -106,8 +94,19 @@ LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
|
||||
@@ -48,6 +48,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, U
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -77,29 +78,33 @@ def save_model_card(
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
dataset: {dataset_name}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
model_description = f"""
|
||||
# Text-to-image finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{base_model}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
|
||||
This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
|
||||
{img_str}
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = [
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
|
||||
@@ -167,7 +167,10 @@ vae_conversion_map_attn = [
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
if not w.ndim == 1:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
else:
|
||||
return w
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
@@ -321,11 +324,18 @@ if __name__ == "__main__":
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Convert text encoder 1
|
||||
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Convert text encoder 2
|
||||
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
|
||||
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
|
||||
# We call the `.T.contiguous()` to match what's done in
|
||||
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
|
||||
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
|
||||
"conditioner.embedders.1.model.text_projection.weight"
|
||||
).T.contiguous()
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
|
||||
|
||||
@@ -170,7 +170,10 @@ vae_extra_conversion_map = [
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
if not w.ndim == 1:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
else:
|
||||
return w
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
|
||||
@@ -126,8 +126,8 @@ _deps = [
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"tensorboard",
|
||||
"torch>=1.4,<2.2.0",
|
||||
"torchvision<0.17",
|
||||
"torch>=1.4",
|
||||
"torchvision",
|
||||
"transformers>=4.25.1",
|
||||
"urllib3<=2.0.0",
|
||||
]
|
||||
|
||||
@@ -38,8 +38,8 @@ deps = {
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
"torch": "torch>=1.4,<2.2.0",
|
||||
"torchvision": "torchvision<0.17",
|
||||
"torch": "torch>=1.4",
|
||||
"torchvision": "torchvision",
|
||||
"transformers": "transformers>=4.25.1",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
}
|
||||
|
||||
@@ -38,6 +38,9 @@ class FromOriginalVAEMixin:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
config_file (`str`, *optional*):
|
||||
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
|
||||
https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
@@ -65,6 +68,13 @@ class FromOriginalVAEMixin:
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
@@ -92,6 +102,7 @@ class FromOriginalVAEMixin:
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -103,6 +114,13 @@ class FromOriginalVAEMixin:
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
|
||||
if (config_file is not None) and (original_config_file is not None):
|
||||
raise ValueError(
|
||||
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
|
||||
)
|
||||
|
||||
original_config_file = original_config_file or config_file
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
@@ -118,7 +136,10 @@ class FromOriginalVAEMixin:
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
component = create_diffusers_vae_model_from_ldm(
|
||||
class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
|
||||
)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
@@ -166,8 +166,7 @@ class IPAdapterMixin:
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=Path(subfolder, "image_encoder").as_posix(),
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.image_encoder = image_encoder
|
||||
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
|
||||
|
||||
|
||||
+89
-302
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
@@ -26,7 +25,7 @@ from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from .. import __version__
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
@@ -34,7 +33,6 @@ from ..utils import (
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
@@ -51,10 +49,9 @@ from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -106,6 +103,9 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
@@ -397,16 +397,17 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
if all(key.startswith("unet.unet") for key in keys):
|
||||
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
|
||||
deprecate("unet.unet keys", "0.27", deprecation_message)
|
||||
|
||||
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
@@ -427,9 +428,7 @@ class LoraLoaderMixin:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
|
||||
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
|
||||
@@ -518,6 +517,11 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -539,34 +543,21 @@ class LoraLoaderMixin:
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
else:
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -576,84 +567,25 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft import LoraConfig
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
|
||||
)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
else:
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
|
||||
for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(
|
||||
text_encoder_lora_state_dict, strict=False
|
||||
)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@@ -689,6 +621,8 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
@@ -705,8 +639,6 @@ class LoraLoaderMixin:
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
@@ -754,118 +686,20 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if USE_PEFT_BACKEND:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
|
||||
remove_method = recurse_remove_peft_layers
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
|
||||
# In case text encoder have no Lora attached
|
||||
if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
if USE_PEFT_BACKEND:
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
attn_module.k_proj.lora_linear_layer = None
|
||||
attn_module.v_proj.lora_linear_layer = None
|
||||
attn_module.out_proj.lora_linear_layer = None
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_linear_layer = None
|
||||
mlp_module.fc2.lora_linear_layer = None
|
||||
|
||||
@classmethod
|
||||
def _modify_text_encoder(
|
||||
cls,
|
||||
text_encoder,
|
||||
lora_scale=1,
|
||||
network_alphas=None,
|
||||
rank: Union[Dict[str, int], int] = 4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
|
||||
|
||||
lora_parameters.extend(model.lora_linear_layer.parameters())
|
||||
return model
|
||||
|
||||
# First, remove any monkey-patch that might have been applied before
|
||||
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
|
||||
|
||||
lora_parameters = []
|
||||
network_alphas = {} if network_alphas is None else network_alphas
|
||||
is_network_alphas_populated = len(network_alphas) > 0
|
||||
|
||||
for name, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
|
||||
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
|
||||
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
|
||||
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
|
||||
|
||||
if isinstance(rank, dict):
|
||||
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
|
||||
else:
|
||||
current_rank = rank
|
||||
|
||||
attn_module.q_proj = create_patched_linear_lora(
|
||||
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.k_proj = create_patched_linear_lora(
|
||||
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.v_proj = create_patched_linear_lora(
|
||||
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.out_proj = create_patched_linear_lora(
|
||||
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
|
||||
if patch_mlp:
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
|
||||
|
||||
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
|
||||
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
|
||||
|
||||
mlp_module.fc1 = create_patched_linear_lora(
|
||||
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
|
||||
)
|
||||
mlp_module.fc2 = create_patched_linear_lora(
|
||||
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
|
||||
)
|
||||
|
||||
if is_network_alphas_populated and len(network_alphas) > 0:
|
||||
raise ValueError(
|
||||
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
|
||||
)
|
||||
|
||||
return lora_parameters
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
@@ -1039,6 +873,8 @@ class LoraLoaderMixin:
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
if self.num_fused_loras > 1:
|
||||
@@ -1050,52 +886,26 @@ class LoraLoaderMixin:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
|
||||
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
|
||||
"backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
|
||||
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1120,40 +930,18 @@ class LoraLoaderMixin:
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if unfuse_unet:
|
||||
if not USE_PEFT_BACKEND:
|
||||
unet.unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1434,6 +1222,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
@@ -1538,17 +1329,13 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if USE_PEFT_BACKEND:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@@ -175,6 +175,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
}
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
@@ -518,7 +519,10 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
scaling_factor = scaling_factor or original_config["model"]["params"]["scale_factor"]
|
||||
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
|
||||
scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
elif scaling_factor is None:
|
||||
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
|
||||
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
@@ -1112,7 +1116,6 @@ def create_text_encoder_from_open_clip_checkpoint(
|
||||
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
|
||||
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
|
||||
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
|
||||
|
||||
else:
|
||||
text_model_dict[diffusers_key] = checkpoint[key]
|
||||
|
||||
@@ -1174,7 +1177,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
|
||||
|
||||
def create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18125
|
||||
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
|
||||
):
|
||||
# import here to avoid circular imports
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from importlib import import_module
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@@ -509,6 +510,15 @@ class Attention(nn.Module):
|
||||
# The `Attention` class can call different attention processors / attention functions
|
||||
# here we simply pass along all tensors to the selected processor class
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
||||
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import logging
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_transformers_available
|
||||
|
||||
|
||||
@@ -82,6 +82,9 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
|
||||
class PatchedLoraProjection(torch.nn.Module):
|
||||
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
|
||||
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__()
|
||||
from ..models.lora import LoRALinearLayer
|
||||
|
||||
@@ -293,10 +296,16 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
|
||||
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("set_lora_layer", "1.0.0", deprecation_message)
|
||||
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
@@ -371,10 +380,15 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
|
||||
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
||||
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("set_lora_layer", "1.0.0", deprecation_message)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -164,14 +163,6 @@ def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool
|
||||
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
|
||||
|
||||
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
class AutoPipelineForText2Image(ConfigMixin):
|
||||
r"""
|
||||
|
||||
@@ -391,7 +382,7 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
|
||||
expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
@@ -668,7 +659,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
|
||||
expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
@@ -943,7 +934,7 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
|
||||
expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
|
||||
@@ -1423,6 +1423,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
device_type = torch_device.type
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
self._offload_device = device
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
@@ -1472,7 +1473,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
hook.remove()
|
||||
|
||||
# make sure the model is in the same state as before calling it
|
||||
self.enable_model_cpu_offload()
|
||||
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
@@ -1508,6 +1509,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
device_type = torch_device.type
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
self._offload_device = device
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,64 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class PEFTLoRALoading(unittest.TestCase):
|
||||
def get_dummy_inputs(self):
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"generator": torch.manual_seed(0),
|
||||
}
|
||||
return pipeline_inputs
|
||||
|
||||
def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
# For details on how the LoRA was obtained, refer to:
|
||||
# https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
@@ -21,6 +21,7 @@ from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
@@ -48,6 +49,20 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
|
||||
|
||||
|
||||
class AutoPipelineFastTest(unittest.TestCase):
|
||||
@property
|
||||
def dummy_image_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPVisionConfig(
|
||||
hidden_size=1,
|
||||
projection_dim=1,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
image_size=1,
|
||||
intermediate_size=1,
|
||||
patch_size=1,
|
||||
)
|
||||
return CLIPVisionModelWithProjection(config)
|
||||
|
||||
def test_from_pipe_consistent(self):
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
|
||||
@@ -204,6 +219,20 @@ class AutoPipelineFastTest(unittest.TestCase):
|
||||
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
|
||||
assert "controlnet" in pipe_control_img2img.components
|
||||
|
||||
def test_from_pipe_optional_components(self):
|
||||
image_encoder = self.dummy_image_encoder
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe",
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
|
||||
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
|
||||
assert pipe.image_encoder is not None
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pipe(pipe, image_encoder=None)
|
||||
assert pipe.image_encoder is None
|
||||
|
||||
|
||||
@slow
|
||||
class AutoPipelineIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -36,10 +36,10 @@ from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
logging,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import CaptureLogger, torch_device
|
||||
|
||||
from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
@@ -48,6 +48,9 @@ from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SchedulerObject(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@@ -253,6 +256,60 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_classes = ()
|
||||
forward_default_kwargs = ()
|
||||
|
||||
@property
|
||||
def default_num_inference_steps(self):
|
||||
return 50
|
||||
|
||||
@property
|
||||
def default_timestep(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
try:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = self.scheduler_classes[0](**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
timestep = scheduler.timesteps[0]
|
||||
except NotImplementedError:
|
||||
logger.warning(
|
||||
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
|
||||
f" `default_timestep` will be set to the default value of 1."
|
||||
)
|
||||
timestep = 1
|
||||
|
||||
return timestep
|
||||
|
||||
# NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
|
||||
# default_timestep comes earlier in the timestep schedule than default_timestep_2)
|
||||
@property
|
||||
def default_timestep_2(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
try:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = self.scheduler_classes[0](**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
if len(scheduler.timesteps) >= 2:
|
||||
timestep_2 = scheduler.timesteps[1]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
|
||||
f" scheduler of length {len(scheduler.timesteps)} < 2. The default `default_timestep_2` value of 0"
|
||||
f" will be used."
|
||||
)
|
||||
timestep_2 = 0
|
||||
except NotImplementedError:
|
||||
logger.warning(
|
||||
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
|
||||
f" `default_timestep_2` will be set to the default value of 0."
|
||||
)
|
||||
timestep_2 = 0
|
||||
|
||||
return timestep_2
|
||||
|
||||
@property
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
@@ -313,6 +370,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
time_step = time_step if time_step is not None else self.default_timestep
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
|
||||
@@ -371,6 +429,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
time_step = time_step if time_step is not None else self.default_timestep
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
@@ -411,10 +470,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
timestep = 1
|
||||
timestep = self.default_timestep
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
timestep = float(timestep)
|
||||
|
||||
@@ -497,10 +556,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
timestep_0 = 1
|
||||
timestep_1 = 0
|
||||
timestep_0 = self.default_timestep
|
||||
timestep_1 = self.default_timestep_2
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
@@ -558,9 +617,9 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", 50)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
timestep = 0
|
||||
timestep = self.default_timestep
|
||||
if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
|
||||
timestep = 1
|
||||
|
||||
@@ -644,7 +703,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
continue
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(100)
|
||||
scheduler.set_timesteps(self.default_num_inference_steps)
|
||||
|
||||
sample = self.dummy_sample.to(torch_device)
|
||||
if scheduler_class == CMStochasticIterativeScheduler:
|
||||
|
||||
Reference in New Issue
Block a user