Update conversion script to correctly handle SD 2 (#1511)
* Conversion SD 2
* finish
parent 22b9cb086b
commit f21415d1d9
@@ -33,6 +33,7 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):

     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
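Note on the [5, 10, 20, 20] fallback: SD 2 checkpoints set num_head_channels (64) instead of num_heads in the LDM config, so the per-block head counts have to be hard-coded. A minimal sketch of where those numbers come from, assuming the standard SD 2 UNet block widths (values not read from any config here):

# Standard SD 2 UNet block widths; SD 2 allocates 64 channels per attention head.
block_out_channels = [320, 640, 1280, 1280]
channels_per_head = 64
print([c // channels_per_head for c in block_out_channels])  # -> [5, 10, 20, 20]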
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )

     return config
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model


+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

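The commented-out block sketches the eventual OpenCLIP -> HF key remapping; until that lands, the function returns the pretrained HF text encoder unchanged. A standalone sketch of just the prefix stripping it would perform (the single checkpoint key below is hypothetical):

# Hypothetical OpenCLIP-style checkpoint with a single transformer key:
checkpoint = {"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": None}

prefix = "cond_stage_model.model.transformer."
text_model_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
print(list(text_model_dict))  # -> ['resblocks.0.attn.in_proj_weight']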
@@ -657,13 +684,22 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
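With both flags defaulting to None, the script auto-detects the checkpoint variant; passing them explicitly bypasses the brittle global-step heuristic further down. A hedged usage example (file names are hypothetical, and the script path is assumed from the repo layout):

python scripts/convert_original_stable_diffusion_to_diffusers.py \
    --checkpoint_path ./768-v-ema.ckpt \
    --image_size 768 \
    --prediction_type v_prediction \
    --dump_path ./stable-diffusion-2-diffusers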
@@ -674,10 +710,26 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()

+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
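Why the shape probe distinguishes the versions: the last dimension of any cross-attention to_k weight equals the text-encoder hidden width, which is 768 for SD v1.x (CLIP ViT-L/14) and 1024 for SD v2 (OpenCLIP ViT-H). The check in isolation (checkpoint path is hypothetical):

import torch

state_dict = torch.load("./sd-checkpoint.ckpt")["state_dict"]  # hypothetical local file
key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
is_v2 = key in state_dict and state_dict[key].shape[-1] == 1024
print("v2" if is_v2 else "v1")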
@@ -685,54 +737,69 @@ if __name__ == "__main__":

     original_config = OmegaConf.load(args.original_config_file)

-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+            # as it relies on a brittle global step parameter here
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512

     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
-    if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
-    elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
-    elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
-    elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
-    elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
-    elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
+    if args.scheduler_type == "pndm":
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
+    elif args.scheduler_type == "lms":
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "euler":
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "euler-ancestral":
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "dpm":
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "ddim":
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)

     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )

-    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)

     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

     vae = AutoencoderKL(**vae_config)
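The restructured scheduler block builds one fully-configured DDIMScheduler and derives every other scheduler from its config, so prediction_type and the beta schedule only need to be set once. The pattern in isolation (a sketch; the beta values below are the usual SD defaults, not read from a yaml):

from diffusers import DDIMScheduler, EulerDiscreteScheduler

ddim = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",
)
# Re-instantiate a compatible scheduler from the shared config;
# fields the target scheduler does not know are ignored.
euler = EulerDiscreteScheduler.from_config(ddim.config)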
@@ -740,7 +807,20 @@ if __name__ == "__main__":

     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
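Once dumped, the converted directory loads like any other diffusers pipeline. A usage sketch (the path matches the hypothetical --dump_path above):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-2-diffusers")
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")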