controlnet sd 2.1 checkpoint conversions (#2593)
* controlnet sd 2.1 checkpoint conversions * remove global_step -> make config file mandatory
This commit is contained in:
@@ -0,0 +1,91 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Conversion script for stable diffusion checkpoints which _only_ contain a contrlnet. """
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--original_config_file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The YAML config file corresponding to the original architecture.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_in_channels",
|
||||||
|
default=None,
|
||||||
|
type=int,
|
||||||
|
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_size",
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help=(
|
||||||
|
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
|
||||||
|
" Base. Use 768 for Stable Diffusion v2."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--extract_ema",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
||||||
|
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
||||||
|
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upcast_attention",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Whether the attention computation should always be upcasted. This is necessary when running stable"
|
||||||
|
" diffusion 2.1."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from_safetensors",
|
||||||
|
action="store_true",
|
||||||
|
help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to_safetensors",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to store pipeline in safetensors format or not.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||||
|
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
controlnet = download_controlnet_from_original_ckpt(
|
||||||
|
checkpoint_path=args.checkpoint_path,
|
||||||
|
original_config_file=args.original_config_file,
|
||||||
|
image_size=args.image_size,
|
||||||
|
extract_ema=args.extract_ema,
|
||||||
|
num_in_channels=args.num_in_channels,
|
||||||
|
upcast_attention=args.upcast_attention,
|
||||||
|
from_safetensors=args.from_safetensors,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
||||||
@@ -954,6 +954,25 @@ def stable_unclip_image_noising_components(
|
|||||||
return image_normalizer, image_noising_scheduler
|
return image_normalizer, image_noising_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    """Convert an original-format controlnet state dict into a `ControlNetModel`."""
    # Derive the diffusers-side config from the original LDM config.
    config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    config["upcast_attention"] = upcast_attention
    # `sample_size` is removed before instantiating the model — presumably not
    # a ControlNetModel constructor argument (NOTE(review): confirm).
    config.pop("sample_size")

    model = ControlNetModel(**config)

    converted_state_dict = convert_ldm_unet_checkpoint(
        checkpoint, config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )
    model.load_state_dict(converted_state_dict)

    return model
|
||||||
|
|
||||||
|
|
||||||
def download_from_original_stable_diffusion_ckpt(
|
def download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
original_config_file: str = None,
|
original_config_file: str = None,
|
||||||
@@ -1042,7 +1061,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
print("global_step key not found in model")
|
print("global_step key not found in model")
|
||||||
global_step = None
|
global_step = None
|
||||||
|
|
||||||
if "state_dict" in checkpoint:
|
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
||||||
|
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
||||||
|
while "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
if original_config_file is None:
|
if original_config_file is None:
|
||||||
@@ -1084,6 +1105,14 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
if image_size is None:
|
if image_size is None:
|
||||||
image_size = 512
|
image_size = 512
|
||||||
|
|
||||||
|
if controlnet is None:
|
||||||
|
controlnet = "control_stage_config" in original_config.model.params
|
||||||
|
|
||||||
|
if controlnet:
|
||||||
|
controlnet_model = convert_controlnet_checkpoint(
|
||||||
|
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||||
|
)
|
||||||
|
|
||||||
num_train_timesteps = original_config.model.params.timesteps
|
num_train_timesteps = original_config.model.params.timesteps
|
||||||
beta_start = original_config.model.params.linear_start
|
beta_start = original_config.model.params.linear_start
|
||||||
beta_end = original_config.model.params.linear_end
|
beta_end = original_config.model.params.linear_end
|
||||||
@@ -1143,17 +1172,24 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||||
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
|
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
|
||||||
|
|
||||||
if controlnet is None:
|
|
||||||
controlnet = "control_stage_config" in original_config.model.params
|
|
||||||
|
|
||||||
if controlnet and model_type != "FrozenCLIPEmbedder":
|
|
||||||
raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'")
|
|
||||||
|
|
||||||
if model_type == "FrozenOpenCLIPEmbedder":
|
if model_type == "FrozenOpenCLIPEmbedder":
|
||||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||||
|
|
||||||
if stable_unclip is None:
|
if stable_unclip is None:
|
||||||
|
if controlnet:
|
||||||
|
pipe = StableDiffusionControlNetPipeline(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
controlnet=controlnet_model,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
pipe = StableDiffusionPipeline(
|
pipe = StableDiffusionPipeline(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_model,
|
text_encoder=text_model,
|
||||||
@@ -1238,19 +1274,6 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||||
|
|
||||||
if controlnet:
|
if controlnet:
|
||||||
# Convert the ControlNetModel model.
|
|
||||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
|
||||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
|
||||||
|
|
||||||
ctrlnet_config.pop("sample_size")
|
|
||||||
|
|
||||||
controlnet_model = ControlNetModel(**ctrlnet_config)
|
|
||||||
|
|
||||||
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
|
||||||
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
|
||||||
)
|
|
||||||
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
|
||||||
|
|
||||||
pipe = StableDiffusionControlNetPipeline(
|
pipe = StableDiffusionControlNetPipeline(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_model,
|
text_encoder=text_model,
|
||||||
@@ -1278,3 +1301,55 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def download_controlnet_from_original_ckpt(
    checkpoint_path: str,
    original_config_file: str,
    image_size: int = 512,
    extract_ema: bool = False,
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: Optional[str] = None,
    from_safetensors: bool = False,
) -> ControlNetModel:
    """Load a controlnet-only original checkpoint and convert it to a `ControlNetModel`.

    Args:
        checkpoint_path: Path to the original checkpoint (PyTorch or safetensors file).
        original_config_file: YAML config file for the original architecture (mandatory —
            controlnet-only checkpoints give no way to infer it).
        image_size: Image size the model was trained on (512 for SD v1.X / v2-base, 768 for SD v2).
        extract_ema: Extract EMA weights when the checkpoint contains both EMA and non-EMA weights.
        num_in_channels: Override for the number of input channels; inferred from config if `None`.
        upcast_attention: Always upcast the attention computation (needed for SD 2.1).
        device: `map_location` for `torch.load`; defaults to CUDA when available, else CPU.
        from_safetensors: Load with safetensors instead of `torch.load`.

    Returns:
        The converted `ControlNetModel`.

    Raises:
        ValueError: If omegaconf (or safetensors, when requested) is unavailable, or if
            the original config has no `control_stage_config` entry.
    """
    # Fixed vs. original: the return annotation claimed `StableDiffusionPipeline`,
    # but this function returns the converted `ControlNetModel`.
    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Single load call — the original duplicated it in both branches.
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    original_config = OmegaConf.load(original_config_file)

    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if "control_stage_config" not in original_config.model.params:
        raise ValueError("`control_stage_config` not present in original config")

    controlnet_model = convert_controlnet_checkpoint(
        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
    )

    return controlnet_model
|
||||||
|
|||||||
Reference in New Issue
Block a user