[BUG] Fix convert_vae_pt_to_diffusers bug (#11078)
* fix attention
* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
31c4f24fc1
commit
e121d0ef67
@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i in range(num_down_blocks):
|
for i in range(num_down_blocks):
|
||||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
resnets = [
|
||||||
|
key
|
||||||
|
for key in down_blocks[i]
|
||||||
|
if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
|
||||||
|
]
|
||||||
|
attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
|
||||||
|
|
||||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||||
@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
|
paths = renew_vae_attention_paths(attentions)
|
||||||
|
meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||||
num_mid_res_blocks = 2
|
num_mid_res_blocks = 2
|
||||||
for i in range(1, num_mid_res_blocks + 1):
|
for i in range(1, num_mid_res_blocks + 1):
|
||||||
@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
for i in range(num_up_blocks):
|
for i in range(num_up_blocks):
|
||||||
block_id = num_up_blocks - 1 - i
|
block_id = num_up_blocks - 1 - i
|
||||||
resnets = [
|
resnets = [
|
||||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
key
|
||||||
|
for key in up_blocks[block_id]
|
||||||
|
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
|
||||||
]
|
]
|
||||||
|
attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
|
||||||
|
|
||||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||||
@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
|
paths = renew_vae_attention_paths(attentions)
|
||||||
|
meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||||
num_mid_res_blocks = 2
|
num_mid_res_blocks = 2
|
||||||
for i in range(1, num_mid_res_blocks + 1):
|
for i in range(1, num_mid_res_blocks + 1):
|
||||||
|
|||||||
@ -350,8 +350,14 @@ def create_vae_diffusers_config(original_config, image_size: int):
|
|||||||
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
||||||
|
|
||||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
down_block_types = [
|
||||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
"DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
|
||||||
|
for i, _ in enumerate(block_out_channels)
|
||||||
|
]
|
||||||
|
up_block_types = [
|
||||||
|
"UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D"
|
||||||
|
for i, _ in enumerate(block_out_channels)
|
||||||
|
][::-1]
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"sample_size": image_size,
|
"sample_size": image_size,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user