[BUG] Fix convert_vae_pt_to_diffusers bug (#11078)
* fix attention * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
31c4f24fc1
commit
e121d0ef67
@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
resnets = [
|
||||
key
|
||||
for key in down_blocks[i]
|
||||
if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
|
||||
]
|
||||
attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
|
||||
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||
@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
paths = renew_vae_attention_paths(attentions)
|
||||
meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
key
|
||||
for key in up_blocks[block_id]
|
||||
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
|
||||
]
|
||||
attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
paths = renew_vae_attention_paths(attentions)
|
||||
meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
|
||||
@ -350,8 +350,14 @@ def create_vae_diffusers_config(original_config, image_size: int):
|
||||
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
||||
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
down_block_types = [
|
||||
"DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
|
||||
for i, _ in enumerate(block_out_channels)
|
||||
]
|
||||
up_block_types = [
|
||||
"UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D"
|
||||
for i, _ in enumerate(block_out_channels)
|
||||
][::-1]
|
||||
|
||||
config = {
|
||||
"sample_size": image_size,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user