Update conversion script to correctly handle SD 2 (#1511)

* Conversion SD 2

* finish

commit f21415d1d9 (parent 22b9cb086b), committed via GitHub
@@ -33,6 +33,7 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
    LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
 
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )
 
     return config
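For context on the hunk above: SD 2 checkpoints differ from v1 in two UNet details this code infers, per-block attention head dims and linear (rather than 1x1-conv) projections around the transformer blocks. Below is a minimal standalone sketch, not part of the script, of the kind of config this function ends up producing for a 768x768 "v" checkpoint; the concrete values are illustrative.

```python
# Illustrative sketch: an SD2-style UNet config as derived above.
# block_out_channels/layers_per_block are left at diffusers defaults here.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=96,                      # 768 // vae_scale_factor (8)
    in_channels=4,
    out_channels=4,
    cross_attention_dim=1024,            # OpenCLIP ViT-H hidden size
    attention_head_dim=[5, 10, 20, 20],  # per-block head dims, as above
    use_linear_projection=True,          # SD2 transformer blocks
)
print(sum(p.numel() for p in unet.parameters()))  # roughly 865M parameters
```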
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
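Worth flagging for reviewers: the state-dict remap in this new function is commented out, so the text encoder is not actually converted from the `.ckpt`; it is fetched pre-converted from the Hub. A sketch of what the stub effectively does today (both `from_pretrained` calls appear verbatim elsewhere in this diff):

```python
# What convert_open_clip_checkpoint effectively does while the remap is
# skipped: ignore the checkpoint's text-encoder weights and load the
# already-converted OpenCLIP text encoder (plus tokenizer) from the Hub.
from transformers import CLIPTextModel, CLIPTokenizer

text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
```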
@@ -657,13 +684,22 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
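The switch from `default=512` to `default=None` is load-bearing: `None` acts as a sentinel so the script can later distinguish "the user chose 512" from "the user chose nothing and we should infer". A standalone sketch, not taken from the script:

```python
# Hedged sketch of the None-sentinel pattern: an unset flag now means
# "infer from the checkpoint" instead of silently assuming the v1 value.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--image_size", default=None, type=int)
parser.add_argument("--prediction_type", default=None, type=str)

args = parser.parse_args([])      # user passed neither flag
assert args.image_size is None    # -> auto-detected from global_step later
assert args.prediction_type is None
```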
@@ -674,65 +710,96 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
 
+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
-        args.original_config_file = "./v1-inference.yaml"
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
+            args.original_config_file = "./v1-inference.yaml"
 
     original_config = OmegaConf.load(args.original_config_file)
 
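The new auto-detection hinges on a single tensor shape: SD 2 conditions on OpenCLIP-H embeddings (width 1024) while v1 uses CLIP ViT-L (width 768), so the last dimension of one cross-attention key weight reveals the model family. A self-contained sketch with a stand-in state dict:

```python
# Hedged sketch of the v1/v2 heuristic above; sd_state_dict is a stand-in
# for an already-loaded checkpoint state dict.
import torch

key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
sd_state_dict = {key_name: torch.zeros(320, 1024)}  # illustrative tensor only

is_v2 = key_name in sd_state_dict and sd_state_dict[key_name].shape[-1] == 1024
print("v2" if is_v2 else "v1")  # -> "v2"
```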
-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size=512`
+            # as it relies on a brittle global step parameter here
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
 
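As the NOTE comments admit, telling 2-base apart from the 768-v model relies on a brittle constant: the published 2-base checkpoint happens to stop at global step 875000. A sketch of the fallback logic in isolation; `infer_sd2_settings` is a hypothetical helper, not a function in the script:

```python
# Hedged sketch of the brittle global_step heuristic used above.
def infer_sd2_settings(global_step: int) -> tuple[str, int]:
    if global_step == 875000:          # stable-diffusion-2-base's final step
        return "epsilon", 512
    return "v_prediction", 768         # any other step: assume the 768-v model

print(infer_sd2_settings(875000))      # ('epsilon', 512)
print(infer_sd2_settings(875001))      # ('v_prediction', 768)
```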
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
 
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
     if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
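The refactor above builds one base DDIM scheduler from the checkpoint's beta schedule and then re-materializes it as whichever class was requested, so the betas and the new `prediction_type` stay consistent across every scheduler choice. A standalone sketch of the hand-off pattern, with illustrative beta values:

```python
# Hedged sketch of the from_config hand-off: every scheduler choice inherits
# the same training-time config, including prediction_type, from one base
# DDIM scheduler; keys the target class doesn't accept are simply ignored.
from diffusers import DDIMScheduler, EulerDiscreteScheduler

base = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",
)
euler = EulerDiscreteScheduler.from_config(base.config)
print(euler.config.prediction_type)  # 'v_prediction', carried over
```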
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)
 
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
 
-    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
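A detail worth keeping in mind when reviewing the config changes feeding into this step: `load_state_dict` is strict by default, so a wrong `attention_head_dim` or `use_linear_projection` shows up here as a hard shape or key error rather than silent corruption. A toy illustration with stand-in modules, not the script's models:

```python
# Toy illustration: strict load_state_dict turns a config/weights mismatch
# into an immediate, diagnosable error.
import torch

src = torch.nn.Linear(4, 4)   # stands in for the converted checkpoint
dst = torch.nn.Linear(4, 8)   # mis-configured target, e.g. a wrong head_dim
try:
    dst.load_state_dict(src.state_dict())
except RuntimeError as err:
    print("mismatch caught:", err.__class__.__name__)
```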
@@ -740,7 +807,20 @@ if __name__ == "__main__":
 
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
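Not visible in this truncated hunk, but downstream the assembled pipeline is written out to `--dump_path`, presumably via `save_pretrained`, after which it loads like any diffusers checkpoint. A hedged usage sketch with a hypothetical output path:

```python
# Hedged usage sketch; "./converted-sd2" is a hypothetical --dump_path.
# Note: the SD2 branch above sets safety_checker=None, so the reloaded
# pipeline will not filter its outputs.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("./converted-sd2")
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```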