Compare commits

..

9 Commits

Author SHA1 Message Date
Pedro Cuenca f317695f6b mps: fix torch.where in svd codepath. 2023-12-18 18:58:09 +01:00
d8ahazard 6976cab7ca Fix possible re-conversion issues after extracting from safetensors (#6097)
* Fix possible re-conversion issues after extracting from diffusers

Properly rename specific vae keys.

* Whoops
2023-12-18 11:51:20 +01:00
Dhruv Nair fcbed3fa79 Fix SDXL Inpainting from single file with Refiner Model (#6147)
* update

* update

* update
2023-12-18 11:45:37 +01:00
Sayak Paul b98b314b7a [Training] remove depcreated method from lora scripts. (#6207)
remove depcreated method from lora scripts.
2023-12-18 15:52:43 +05:30
Omar Sanseviero 74558ff65b Nit fix to training params (#6200) 2023-12-18 11:06:16 +01:00
Yudong Jin 49644babd3 Fix the test script in examples/text_to_image/README.md (#6209)
* Update examples/text_to_image/README.md

* Update examples/text_to_image/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-18 15:36:00 +05:30
Sayak Paul 56b3b21693 [Refactor autoencoders] feat: introduce autoencoders module (#6129)
* feat: introduce autoencoders module

* more changes for styling and copy fixing

* path changes in the docs.

* fix: import structure in init.

* fix controlnetxs import
2023-12-18 12:42:15 +05:30
Sayak Paul 9cef07da5a [Benchmarks] fix: lcm benchmarking reporting (#6198)
* fix: lcm benchmarking reporting

* fix generate_csv_dict call.
2023-12-17 15:32:11 +05:30
Sayak Paul 2d94c7838e [Core] feat: enable fused attention projections for other SD and SDXL pipelines (#6179)
* feat: enable fused attention projections for other SD and SDXL pipelines

* add: test for SD fused projections.
2023-12-16 08:45:54 +05:30
37 changed files with 676 additions and 552 deletions
+19
View File
@@ -162,6 +162,25 @@ class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
guidance_scale=1.0,
)
def benchmark(self, args):
flush()
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
benchmark_info = BenchmarkInfo(time=time, memory=memory)
pipeline_class_name = str(self.pipe.__class__.__name__)
flush()
csv_dict = generate_csv_dict(
pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
)
filepath = self.get_result_filepath(args)
write_to_csv(filepath, csv_dict)
print(f"Logs written to: {filepath}")
flush()
class ImageToImageBenchmark(TextToImageBenchmark):
pipeline_class = AutoPipelineForImage2Image
@@ -49,12 +49,12 @@ make_image_grid([original_image, mask_image, image], rows=1, cols=3)
## AsymmetricAutoencoderKL
[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -54,4 +54,4 @@ image
## AutoencoderTinyOutput
[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput
+2 -2
View File
@@ -36,11 +36,11 @@ model = AutoencoderKL.from_single_file(url)
## AutoencoderKLOutput
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
## FlaxAutoencoderKL
@@ -186,7 +186,7 @@ accelerate launch train_unconditional.py \
If you're training with more than one GPU, add the `--multi_gpu` parameter to the training command:
```bash
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
accelerate launch --multi_gpu train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \
--output_dir="ddpm-ema-flowers-64" \
--mixed_precision="fp16" \
+3 -50
View File
@@ -64,39 +64,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -880,16 +847,11 @@ def main(args):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
unet_lora_config = None
text_encoder_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -900,8 +862,6 @@ def main(args):
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
def load_model_hook(models, input_dir):
@@ -918,12 +878,10 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1324,22 +1282,17 @@ def main(args):
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder_lora_config = text_encoder.peft_config["default"]
else:
text_encoder_state_dict = None
text_encoder_lora_config = None
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
# Final inference
@@ -64,39 +64,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -1033,20 +1000,13 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1058,9 +1018,6 @@ def main(args):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
def load_model_hook(models, input_dir):
@@ -1080,19 +1037,17 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1628,29 +1583,21 @@ def main(args):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
# Final inference
+4 -3
View File
@@ -101,8 +101,8 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
```python
import torch
from diffusers import StableDiffusionPipeline
model_path = "path_to_saved_model"
@@ -114,12 +114,13 @@ image.save("yoda-pokemon.png")
```
Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet
```python
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
model_path = "path_to_saved_model"
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained("<initial model>", unet=unet, torch_dtype=torch.float16)
pipe.to("cuda")
@@ -54,39 +54,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
@@ -833,12 +800,10 @@ def main():
accelerator.save_state(save_path)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
unet_lora_config=unet_lora_config,
safe_serialization=True,
)
@@ -900,12 +865,10 @@ def main():
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
unet_lora_config=unet_lora_config,
)
if args.push_to_hub:
@@ -63,39 +63,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -682,20 +649,13 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -707,9 +667,6 @@ def main(args):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
def load_model_hook(models, input_dir):
@@ -729,19 +686,17 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1206,7 +1161,6 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
@@ -1214,23 +1168,15 @@ def main(args):
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
del unet
+1 -1
View File
@@ -12,9 +12,9 @@ from safetensors.torch import load_file as stl
from tqdm import tqdm
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
from diffusers.models.vae import Encoder
args = ArgumentParser()
@@ -159,6 +159,14 @@ vae_conversion_map_attn = [
("proj_out.", "proj_attn."),
]
# This is probably not the most ideal solution, but it does work.
vae_extra_conversion_map = [
("to_q", "q"),
("to_k", "k"),
("to_v", "v"),
("to_out.0", "proj_out"),
]
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
@@ -178,11 +186,20 @@ def convert_vae_state_dict(vae_state_dict):
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
keys_to_rename = {}
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
for weight_name, real_weight_name in vae_extra_conversion_map:
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
keys_to_rename[k] = k.replace(weight_name, real_weight_name)
for k, v in keys_to_rename.items():
if k in new_state_dict:
print(f"Renaming {k} to {v}")
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
del new_state_dict[k]
return new_state_dict
+24 -133
View File
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -104,7 +103,7 @@ class LoraLoaderMixin:
`default_{i}` where i is the total number of adapters being loaded.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -115,7 +114,6 @@ class LoraLoaderMixin:
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
config=metadata,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -127,7 +125,6 @@ class LoraLoaderMixin:
text_encoder=getattr(self, self.text_encoder_name)
if not hasattr(self, "text_encoder")
else self.text_encoder,
config=metadata,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -222,7 +219,6 @@ class LoraLoaderMixin:
}
model_file = None
metadata = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
@@ -252,8 +248,6 @@ class LoraLoaderMixin:
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
@@ -300,7 +294,7 @@ class LoraLoaderMixin:
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alphas, metadata
return state_dict, network_alphas
@classmethod
def _best_guess_weight_name(
@@ -376,7 +370,7 @@ class LoraLoaderMixin:
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -390,8 +384,6 @@ class LoraLoaderMixin:
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
config: (`Dict`):
LoRA configuration parsed from the state dict.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
@@ -451,9 +443,7 @@ class LoraLoaderMixin:
if "lora_B" in key:
rank[key] = val.shape[1]
if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config["unet"])
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -494,7 +484,6 @@ class LoraLoaderMixin:
network_alphas,
text_encoder,
prefix=None,
config=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
@@ -513,8 +502,6 @@ class LoraLoaderMixin:
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
config (`Dict`):
LoRA configuration parsed from state dict.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
@@ -588,11 +575,10 @@ class LoraLoaderMixin:
if USE_PEFT_BACKEND:
from peft import LoraConfig
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -800,8 +786,6 @@ class LoraLoaderMixin:
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
unet_lora_config=None,
text_encoder_lora_config=None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -829,54 +813,21 @@ class LoraLoaderMixin:
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
if not USE_PEFT_BACKEND and not safe_serialization:
if unet_lora_config or text_encoder_lora_config:
raise ValueError(
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
)
elif USE_PEFT_BACKEND and safe_serialization:
from peft import LoraConfig
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
state_dict = {}
metadata = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
if unet_lora_layers:
prefix = "unet"
unet_state_dict = pack_weights(unet_lora_layers, prefix)
state_dict.update(unet_state_dict)
if unet_lora_config is not None:
unet_metadata = pack_metadata(unet_lora_config, prefix)
metadata.update(unet_metadata)
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers:
prefix = "text_encoder"
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
state_dict.update(text_encoder_state_dict)
if text_encoder_lora_config is not None:
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
metadata.update(text_encoder_metadata)
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
# Save the model
cls.write_lora_layers(
@@ -886,7 +837,6 @@ class LoraLoaderMixin:
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
metadata=metadata,
)
@staticmethod
@@ -897,11 +847,7 @@ class LoraLoaderMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
metadata=None,
):
if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -909,10 +855,8 @@ class LoraLoaderMixin:
if save_function is None:
if safe_serialization:
def save_function(weights, filename, metadata):
if metadata is None:
metadata = {"format": "pt"}
return safetensors.torch.save_file(weights, filename, metadata=metadata)
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
@@ -925,10 +869,7 @@ class LoraLoaderMixin:
else:
weight_name = LORA_WEIGHT_NAME
if save_function != torch.save:
save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
else:
save_function(state_dict, os.path.join(save_directory, weight_name))
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
def unload_lora_weights(self):
@@ -1360,7 +1301,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
# pipeline.
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas, metadata = self.lora_state_dict(
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
@@ -1370,12 +1311,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=self.unet,
config=metadata,
adapter_name=adapter_name,
_pipeline=self,
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -1383,7 +1319,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
config=metadata,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
@@ -1396,7 +1331,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
config=metadata,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
@@ -1410,9 +1344,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
unet_lora_config=None,
text_encoder_lora_config=None,
text_encoder_2_lora_config=None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1440,63 +1371,24 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
if not USE_PEFT_BACKEND and not safe_serialization:
if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
raise ValueError(
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
)
elif USE_PEFT_BACKEND and safe_serialization:
from peft import LoraConfig
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
)
state_dict = {}
metadata = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
if unet_lora_layers:
prefix = "unet"
unet_state_dict = pack_weights(unet_lora_layers, prefix)
state_dict.update(unet_state_dict)
if unet_lora_config is not None:
unet_metadata = pack_metadata(unet_lora_config, prefix)
metadata.update(unet_metadata)
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers:
prefix = "text_encoder"
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
state_dict.update(text_encoder_state_dict)
if text_encoder_lora_config is not None:
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
metadata.update(text_encoder_metadata)
prefix = "text_encoder_2"
text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
state_dict.update(text_encoder_2_state_dict)
if text_encoder_2_lora_config is not None:
text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
metadata.update(text_encoder_2_metadata)
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
cls.write_lora_layers(
state_dict=state_dict,
@@ -1505,7 +1397,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
metadata=metadata,
)
def _remove_text_encoder_monkey_patch(self):
+4
View File
@@ -169,10 +169,12 @@ class FromSingleFileMixin:
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None)
text_encoder_2 = kwargs.pop("text_encoder_2", None)
vae = kwargs.pop("vae", None)
controlnet = kwargs.pop("controlnet", None)
adapter = kwargs.pop("adapter", None)
tokenizer = kwargs.pop("tokenizer", None)
tokenizer_2 = kwargs.pop("tokenizer_2", None)
torch_dtype = kwargs.pop("torch_dtype", None)
@@ -274,8 +276,10 @@ class FromSingleFileMixin:
load_safety_checker=load_safety_checker,
prediction_type=prediction_type,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
original_config_file=original_config_file,
config_files=config_files,
local_files_only=local_files_only,
+12 -10
View File
@@ -26,11 +26,11 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
@@ -58,11 +58,13 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .controlnetxs import ControlNetXSModel
from .dual_transformer_2d import DualTransformer2DModel
@@ -0,0 +1,5 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
@@ -16,10 +16,10 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
@@ -27,8 +27,8 @@ from .attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -16,14 +16,14 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -18,10 +18,10 @@ from typing import Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
@@ -18,20 +18,20 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers import ConsistencyDecoderScheduler
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from ..utils.torch_utils import randn_tensor
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers import ConsistencyDecoderScheduler
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_2d import UNet2DModel
from ..modeling_utils import ModelMixin
from ..unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -153,7 +153,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.use_slicing = False
self.use_tiling = False
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -162,7 +162,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
@@ -170,7 +170,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
return DecoderOutput(sample=x_0)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
@@ -18,11 +18,11 @@ import numpy as np
import torch
import torch.nn as nn
from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import (
from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
+1 -1
View File
@@ -26,7 +26,7 @@ from ..utils import BaseOutput, logging
from .attention_processor import (
AttentionProcessor,
)
from .autoencoder_kl import AutoencoderKL
from .autoencoders import AutoencoderKL
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
+1 -1
View File
@@ -1334,7 +1334,7 @@ class AlphaBlender(nn.Module):
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
torch.ones(1, 1, device=image_only_indicator.device, dtype=self.mix_factor.dtype),
torch.sigmoid(self.mix_factor)[..., None],
)
+1 -1
View File
@@ -20,8 +20,8 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
@dataclass
@@ -23,6 +23,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import FusedAttnProcessor2_0
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -655,6 +656,65 @@ class AltDiffusionPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import FusedAttnProcessor2_0
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -715,6 +716,65 @@ class AltDiffusionImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
@@ -1153,7 +1153,9 @@ def download_from_original_stable_diffusion_ckpt(
vae_path=None,
vae=None,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
config_files=None,
) -> DiffusionPipeline:
"""
@@ -1232,7 +1234,9 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
@@ -1339,7 +1343,11 @@ def download_from_original_stable_diffusion_ckpt(
else:
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
if num_in_channels is None and pipeline_class in [
StableDiffusionInpaintPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLControlNetInpaintPipeline,
]:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7
@@ -1686,7 +1694,9 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor,
)
elif model_type in ["SDXL", "SDXL-Refiner"]:
if model_type == "SDXL":
is_refiner = model_type == "SDXL-Refiner"
if (is_refiner is False) and (tokenizer is None):
try:
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", local_files_only=local_files_only
@@ -1695,7 +1705,11 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)
if (is_refiner is False) and (text_encoder is None):
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
if tokenizer_2 is None:
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
@@ -1705,95 +1719,69 @@ def download_from_original_stable_diffusion_ckpt(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
)
if text_encoder_2 is None:
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."
text_encoder_2 = convert_open_clip_checkpoint(
checkpoint,
config_name,
prefix="conditioner.embedders.1.model.",
prefix=prefix,
has_projection=True,
local_files_only=local_files_only,
**config_kwargs,
)
if is_accelerate_available(): # SBM Now move model to cpu.
if model_type in ["SDXL", "SDXL-Refiner"]:
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
if is_accelerate_available(): # SBM Now move model to cpu.
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
elif adapter:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
adapter=adapter,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
tokenizer = None
text_encoder = None
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
)
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
checkpoint,
config_name,
prefix="conditioner.embedders.0.model.",
has_projection=True,
local_files_only=local_files_only,
**config_kwargs,
)
if is_accelerate_available(): # SBM Now move model to cpu.
if model_type in ["SDXL", "SDXL-Refiner"]:
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
pipe = StableDiffusionXLImg2ImgPipeline(
if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
requires_aesthetics_score=True,
force_zeros_for_empty_prompt=False,
force_zeros_for_empty_prompt=True,
)
elif adapter:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
adapter=adapter,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
pipeline_kwargs = {
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"unet": unet,
"scheduler": scheduler,
}
if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
pipeline_class == StableDiffusionXLInpaintPipeline
):
pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
if is_refiner:
pipeline_kwargs.update({"force_zeros_for_empty_prompt": False})
pipe = pipeline_class(**pipeline_kwargs)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
@@ -23,6 +23,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import FusedAttnProcessor2_0
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -650,6 +651,67 @@ class StableDiffusionPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import FusedAttnProcessor2_0
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -718,6 +719,67 @@ class StableDiffusionImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import FusedAttnProcessor2_0
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -844,6 +845,67 @@ class StableDiffusionInpaintPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
@@ -35,6 +35,7 @@ from ...loaders import (
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -864,6 +865,67 @@ class StableDiffusionXLImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
@@ -36,6 +36,7 @@ from ...loaders import (
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -1084,6 +1085,67 @@ class StableDiffusionXLInpaintPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
@@ -19,8 +19,8 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
from ...models.modeling_utils import ModelMixin
from ...models.vae import DecoderOutput, VectorQuantizer
from ...models.vq_model import VQEncoderOutput
from ...utils.accelerate_utils import apply_forward_hook
+3 -10
View File
@@ -138,17 +138,11 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
# Try to retrive config.
alpha_retrieved = False
if config is not None:
lora_alpha = config["lora_alpha"]
alpha_retrieved = True
if len(set(rank_dict.values())) > 1:
# get the rank occuring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -160,8 +154,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None,
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occuring the most number of times
if not alpha_retrieved:
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
@@ -172,7 +165,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None,
}
else:
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
elif not alpha_retrieved:
else:
lora_alpha = set(network_alpha_dict.values()).pop()
# layer names without the Diffusers specific
+3 -69
View File
@@ -107,9 +107,8 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None
vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
def get_dummy_components(self, scheduler_cls=None):
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
lora_alpha = 4 if lora_alpha is None else lora_alpha
torch.manual_seed(0)
unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -124,14 +123,11 @@ class PeftLoraLoaderMixinTests:
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
text_lora_config = LoraConfig(
r=4,
lora_alpha=lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
)
unet_lora_config = LoraConfig(
r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
@@ -718,68 +714,6 @@ class PeftLoraLoaderMixinTests:
"Fused lora should change the output",
)
def test_if_lora_alpha_is_correctly_parsed(self):
lora_alpha = 8
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.unet.add_adapter(unet_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
# Inference works?
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
unet_state_dict = get_peft_model_state_dict(pipe.unet)
unet_lora_config = pipe.unet.peft_config["default"]
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
text_encoder_2_lora_config=text_encoder_2_lora_config,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
loaded_pipe = self.pipeline_class(**components)
loaded_pipe.load_lora_weights(tmpdirname)
# Inference works?
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
assert (
loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for UNet."
assert (
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder."
if self.has_two_text_encoders:
assert (
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder 2."
def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
@@ -661,6 +661,37 @@ class StableDiffusionPipelineFastTests(
output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
), "Disabling of FreeU should lead to results similar to the default pipeline results."
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
sd_pipe.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
sd_pipe.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow
@require_torch_gpu