[BUG] Fix convert_vae_pt_to_diffusers bug (#11078)

* fix attention

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Yuqian Hong 2025-04-10 13:59:45 +08:00 committed by GitHub
parent 31c4f24fc1
commit e121d0ef67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 4 deletions

View File

@@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
     }
     for i in range(num_down_blocks):
-        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+        resnets = [
+            key
+            for key in down_blocks[i]
+            if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
+        ]
+        attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
@@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
             meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
             assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+        paths = renew_vae_attention_paths(attentions)
+        meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
     mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
     num_mid_res_blocks = 2
     for i in range(1, num_mid_res_blocks + 1):
@@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i
         resnets = [
-            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+            key
+            for key in up_blocks[block_id]
+            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
         ]
+        attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
             meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
             assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+        paths = renew_vae_attention_paths(attentions)
+        meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
     mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
     num_mid_res_blocks = 2
     for i in range(1, num_mid_res_blocks + 1):

View File

@@ -350,8 +350,14 @@ def create_vae_diffusers_config(original_config, image_size: int):
     _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
     block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
-    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
-    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+    down_block_types = [
+        "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
+        for i, _ in enumerate(block_out_channels)
+    ]
+    up_block_types = [
+        "UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D"
+        for i, _ in enumerate(block_out_channels)
+    ][::-1]
     config = {
         "sample_size": image_size,