Fix conversion script
This commit is contained in:
parent
87060e6a9c
commit
3f1e95928e
86
debug_conversion.py
Executable file
86
debug_conversion.py
Executable file
@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from diffusers import UNetUnconditionalModel
|
||||||
|
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model_id = "fusing/latent-diffusion-celeba-256"
|
||||||
|
subfolder = "unet"
|
||||||
|
#model_id = "fusing/unet-ldm-dummy"
|
||||||
|
#subfolder = None
|
||||||
|
|
||||||
|
checkpoint = "diffusion_model.pt"
|
||||||
|
config = "config.json"
|
||||||
|
|
||||||
|
if subfolder is not None:
|
||||||
|
checkpoint = os.path.join(subfolder, checkpoint)
|
||||||
|
config = os.path.join(subfolder, config)
|
||||||
|
|
||||||
|
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
|
||||||
|
config_path = hf_hub_download(model_id, config)
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
checkpoint = convert_ldm_checkpoint(original_checkpoint, config)
|
||||||
|
|
||||||
|
|
||||||
|
def current_codebase_conversion():
|
||||||
|
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(0)
|
||||||
|
|
||||||
|
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||||
|
time_step = torch.tensor([10] * noise.shape[0])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(noise, time_step)
|
||||||
|
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
currently_converted_checkpoint = current_codebase_conversion()
|
||||||
|
torch.save(currently_converted_checkpoint, 'currently_converted_checkpoint.pt')
|
||||||
|
|
||||||
|
|
||||||
|
def diff_between_checkpoints(ch_0, ch_1):
|
||||||
|
all_layers_included = False
|
||||||
|
|
||||||
|
if not set(ch_0.keys()) == set(ch_1.keys()):
|
||||||
|
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
|
||||||
|
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
|
||||||
|
print(f"\t{key}")
|
||||||
|
|
||||||
|
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
|
||||||
|
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
|
||||||
|
print(f"\t{key}")
|
||||||
|
else:
|
||||||
|
print("Keys are the same between the two checkpoints")
|
||||||
|
all_layers_included = True
|
||||||
|
|
||||||
|
keys = ch_0.keys()
|
||||||
|
non_equal_keys = []
|
||||||
|
|
||||||
|
if all_layers_included:
|
||||||
|
for key in keys:
|
||||||
|
try:
|
||||||
|
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
|
||||||
|
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(e)
|
||||||
|
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
|
||||||
|
|
||||||
|
if len(non_equal_keys):
|
||||||
|
non_equal_keys = '\n\t'.join(non_equal_keys)
|
||||||
|
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
|
||||||
|
else:
|
||||||
|
print("All keys are equal across checkpoints.")
|
||||||
|
|
||||||
|
|
||||||
|
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
|
||||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
|||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None):
|
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
|
||||||
"""
|
"""
|
||||||
This does the final conversion step: take locally converted weights and apply a global renaming
|
This does the final conversion step: take locally converted weights and apply a global renaming
|
||||||
to them. It splits attention layers, and takes into account additional replacements
|
to them. It splits attention layers, and takes into account additional replacements
|
||||||
@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
|
|||||||
# Splits the attention layers into three variables.
|
# Splits the attention layers into three variables.
|
||||||
if attention_paths_to_split is not None:
|
if attention_paths_to_split is not None:
|
||||||
for path, path_map in attention_paths_to_split.items():
|
for path, path_map in attention_paths_to_split.items():
|
||||||
query, key, value = torch.split(old_checkpoint[path], int(old_checkpoint[path].shape[0] / 3))
|
old_tensor = old_checkpoint[path]
|
||||||
|
channels = old_tensor.shape[0] // 3
|
||||||
|
|
||||||
checkpoint[path_map['query']] = query
|
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||||
checkpoint[path_map['key']] = key
|
|
||||||
checkpoint[path_map['value']] = value
|
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||||
|
|
||||||
|
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
||||||
|
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||||
|
|
||||||
|
checkpoint[path_map['query']] = query.reshape(target_shape)
|
||||||
|
checkpoint[path_map['key']] = key.reshape(target_shape)
|
||||||
|
checkpoint[path_map['value']] = value.reshape(target_shape)
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
new_path = path['new']
|
new_path = path['new']
|
||||||
@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
|
|||||||
for replacement in additional_replacements:
|
for replacement in additional_replacements:
|
||||||
new_path = new_path.replace(replacement['old'], replacement['new'])
|
new_path = new_path.replace(replacement['old'], replacement['new'])
|
||||||
|
|
||||||
checkpoint[new_path] = old_checkpoint[path['old']]
|
# proj_attn.weight has to be converted from conv 1D to linear
|
||||||
|
if "proj_attn.weight" in new_path:
|
||||||
|
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
|
||||||
|
else:
|
||||||
|
checkpoint[new_path] = old_checkpoint[path['old']]
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_checkpoint(checkpoint, config):
|
def convert_ldm_checkpoint(checkpoint, config):
|
||||||
@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||||||
paths = renew_resnet_paths(resnets)
|
paths = renew_resnet_paths(resnets)
|
||||||
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||||
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
|
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op])
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
|
||||||
|
|
||||||
if len(attentions):
|
if len(attentions):
|
||||||
paths = renew_attention_paths(attentions)
|
paths = renew_attention_paths(attentions)
|
||||||
@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||||||
new_checkpoint,
|
new_checkpoint,
|
||||||
checkpoint,
|
checkpoint,
|
||||||
additional_replacements=[meta_path],
|
additional_replacements=[meta_path],
|
||||||
attention_paths_to_split=to_split
|
attention_paths_to_split=to_split,
|
||||||
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
resnet_0 = middle_blocks[0]
|
resnet_0 = middle_blocks[0]
|
||||||
attentions = middle_blocks[1]
|
attentions = middle_blocks[1]
|
||||||
resnet_1 = middle_blocks[2]
|
resnet_1 = middle_blocks[2]
|
||||||
|
|
||||||
resnet_0_paths = renew_resnet_paths(resnet_0)
|
resnet_0_paths = renew_resnet_paths(resnet_0)
|
||||||
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint)
|
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)
|
||||||
|
|
||||||
resnet_1_paths = renew_resnet_paths(resnet_1)
|
resnet_1_paths = renew_resnet_paths(resnet_1)
|
||||||
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint)
|
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)
|
||||||
|
|
||||||
attentions_paths = renew_attention_paths(attentions)
|
attentions_paths = renew_attention_paths(attentions)
|
||||||
to_split = {
|
to_split = {
|
||||||
@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||||||
'value': 'mid.attentions.0.value.weight',
|
'value': 'mid.attentions.0.value.weight',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split)
|
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
|
||||||
|
|
||||||
for i in range(num_output_blocks):
|
for i in range(num_output_blocks):
|
||||||
block_id = i // (config['num_res_blocks'] + 1)
|
block_id = i // (config['num_res_blocks'] + 1)
|
||||||
@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||||||
paths = renew_resnet_paths(resnets)
|
paths = renew_resnet_paths(resnets)
|
||||||
|
|
||||||
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path])
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
||||||
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
||||||
@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||||||
if len(attentions) == 2:
|
if len(attentions) == 2:
|
||||||
attentions = []
|
attentions = []
|
||||||
|
|
||||||
|
|
||||||
if len(attentions):
|
if len(attentions):
|
||||||
paths = renew_attention_paths(attentions)
|
paths = renew_attention_paths(attentions)
|
||||||
meta_path = {
|
meta_path = {
|
||||||
@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||||||
new_checkpoint,
|
new_checkpoint,
|
||||||
checkpoint,
|
checkpoint,
|
||||||
additional_replacements=[meta_path],
|
additional_replacements=[meta_path],
|
||||||
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None
|
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||||
@ -296,7 +308,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint_path)
|
checkpoint = torch.load(args.checkpoint_path)
|
||||||
|
|
||||||
with open(args.config_file) as f:
|
with open(args.config_file) as f:
|
||||||
@ -304,6 +315,3 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
|
converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
|
||||||
torch.save(checkpoint, args.dump_path)
|
torch.save(checkpoint, args.dump_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user