From e121d0ef675334ac4a51d0a38b767dcf4824ac30 Mon Sep 17 00:00:00 2001 From: Yuqian Hong Date: Thu, 10 Apr 2025 13:59:45 +0800 Subject: [PATCH] [BUG] Fix convert_vae_pt_to_diffusers bug (#11078) * fix attention * Apply style fixes --------- Co-authored-by: github-actions[bot] --- scripts/convert_vae_pt_to_diffusers.py | 20 +++++++++++++++++-- .../stable_diffusion/convert_from_ckpt.py | 10 ++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/scripts/convert_vae_pt_to_diffusers.py b/scripts/convert_vae_pt_to_diffusers.py index 13ceca40f3..8c7dc71ddf 100644 --- a/scripts/convert_vae_pt_to_diffusers.py +++ b/scripts/convert_vae_pt_to_diffusers.py @@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config): } for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key + ] + attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( @@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config): meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + paths = renew_vae_attention_paths(attentions) + meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): @@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config): for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key ] + attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ @@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config): meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + paths = renew_vae_attention_paths(attentions) + meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index d337aba8e9..568ae7f7d6 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -350,8 +350,14 @@ def create_vae_diffusers_config(original_config, image_size: int): _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + down_block_types = [ + "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D" + for i, _ in enumerate(block_out_channels) + ] + up_block_types = [ + "UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D" + for i, _ in enumerate(block_out_channels) + ][::-1] config = { "sample_size": image_size,