diff --git a/debug_conversion.py b/debug_conversion.py new file mode 100755 index 0000000000..fa0d58d0ea --- /dev/null +++ b/debug_conversion.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +import json +import os +from diffusers import UNetUnconditionalModel +from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint +from huggingface_hub import hf_hub_download +import torch + +model_id = "fusing/latent-diffusion-celeba-256" +subfolder = "unet" +#model_id = "fusing/unet-ldm-dummy" +#subfolder = None + +checkpoint = "diffusion_model.pt" +config = "config.json" + +if subfolder is not None: + checkpoint = os.path.join(subfolder, checkpoint) + config = os.path.join(subfolder, config) + +original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint)) +config_path = hf_hub_download(model_id, config) + +with open(config_path) as f: + config = json.load(f) + +checkpoint = convert_ldm_checkpoint(original_checkpoint, config) + + +def current_codebase_conversion(): + model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True) + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) + time_step = torch.tensor([10] * noise.shape[0]) + + with torch.no_grad(): + output = model(noise, time_step) + + return model.state_dict() + + +currently_converted_checkpoint = current_codebase_conversion() +torch.save(currently_converted_checkpoint, 'currently_converted_checkpoint.pt') + + +def diff_between_checkpoints(ch_0, ch_1): + all_layers_included = False + + if not set(ch_0.keys()) == set(ch_1.keys()): + print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})") + for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))): + print(f"\t{key}") + + print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})") + for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))): + print(f"\t{key}") + else: + print("Keys are the same between the two checkpoints") + all_layers_included = True + + keys = ch_0.keys() + non_equal_keys = [] + + if all_layers_included: + for key in keys: + try: + if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()): + non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}') + + except RuntimeError as e: + print(e) + non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}') + + if len(non_equal_keys): + non_equal_keys = '\n\t'.join(non_equal_keys) + print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}") + else: + print("All keys are equal across checkpoints.") + + +diff_between_checkpoints(currently_converted_checkpoint, checkpoint) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/scripts/convert_ldm_original_checkpoint_to_diffusers.py index ec58a773ce..1f402a57f7 100644 --- a/scripts/convert_ldm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ldm_original_checkpoint_to_diffusers.py @@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0): return mapping -def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None): +def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits attention layers, and takes into account additional replacements @@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s # Splits the attention layers into three variables. if attention_paths_to_split is not None: for path, path_map in attention_paths_to_split.items(): - query, key, value = torch.split(old_checkpoint[path], int(old_checkpoint[path].shape[0] / 3)) + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 - checkpoint[path_map['query']] = query - checkpoint[path_map['key']] = key - checkpoint[path_map['value']] = value + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map['query']] = query.reshape(target_shape) + checkpoint[path_map['key']] = key.reshape(target_shape) + checkpoint[path_map['value']] = value.reshape(target_shape) for path in paths: new_path = path['new'] @@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s for replacement in additional_replacements: new_path = new_path.replace(replacement['old'], replacement['new']) - checkpoint[new_path] = old_checkpoint[path['old']] + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path['old']] def convert_ldm_checkpoint(checkpoint, config): @@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config): paths = renew_resnet_paths(resnets) meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'} resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'} - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config) if len(attentions): paths = renew_attention_paths(attentions) @@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config): new_checkpoint, checkpoint, additional_replacements=[meta_path], - attention_paths_to_split=to_split + attention_paths_to_split=to_split, + config=config ) - resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config) resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config) attentions_paths = renew_attention_paths(attentions) to_split = { @@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config): 'value': 'mid.attentions.0.value.weight', }, } - assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split) + assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config) for i in range(num_output_blocks): block_id = i // (config['num_res_blocks'] + 1) @@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config): paths = renew_resnet_paths(resnets) meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'} - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config) if ['conv.weight', 'conv.bias'] in output_block_list.values(): index = list(output_block_list.values()).index(['conv.weight', 'conv.bias']) @@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config): if len(attentions) == 2: attentions = [] - if len(attentions): paths = renew_attention_paths(attentions) meta_path = { @@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config): new_checkpoint, checkpoint, additional_replacements=[meta_path], - attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None + attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None, + config=config, ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) @@ -296,7 +308,6 @@ if __name__ == "__main__": args = parser.parse_args() - checkpoint = torch.load(args.checkpoint_path) with open(args.config_file) as f: @@ -304,6 +315,3 @@ if __name__ == "__main__": converted_checkpoint = convert_ldm_checkpoint(checkpoint, config) torch.save(checkpoint, args.dump_path) - - -