|
|
|
@@ -12,23 +12,26 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
|
|
|
|
Creates a config for the diffusers based on the config of the LDM model.
|
|
|
|
|
"""
|
|
|
|
|
if controlnet:
|
|
|
|
|
unet_params = original_config.model.params.control_stage_config.params
|
|
|
|
|
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
|
|
|
|
|
else:
|
|
|
|
|
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
|
|
|
|
|
unet_params = original_config.model.params.unet_config.params
|
|
|
|
|
if (
|
|
|
|
|
"unet_config" in original_config["model"]["params"]
|
|
|
|
|
and original_config["model"]["params"]["unet_config"] is not None
|
|
|
|
|
):
|
|
|
|
|
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
|
|
|
|
else:
|
|
|
|
|
unet_params = original_config.model.params.network_config.params
|
|
|
|
|
unet_params = original_config["model"]["params"]["network_config"]["params"]
|
|
|
|
|
|
|
|
|
|
vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
|
|
|
|
|
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["encoder_config"]["params"]
|
|
|
|
|
|
|
|
|
|
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
|
|
|
|
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
|
|
|
|
|
|
|
|
|
down_block_types = []
|
|
|
|
|
resolution = 1
|
|
|
|
|
for i in range(len(block_out_channels)):
|
|
|
|
|
block_type = (
|
|
|
|
|
"CrossAttnDownBlockSpatioTemporal"
|
|
|
|
|
if resolution in unet_params.attention_resolutions
|
|
|
|
|
if resolution in unet_params["attention_resolutions"]
|
|
|
|
|
else "DownBlockSpatioTemporal"
|
|
|
|
|
)
|
|
|
|
|
down_block_types.append(block_type)
|
|
|
|
@@ -39,32 +42,32 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
|
|
|
|
for i in range(len(block_out_channels)):
|
|
|
|
|
block_type = (
|
|
|
|
|
"CrossAttnUpBlockSpatioTemporal"
|
|
|
|
|
if resolution in unet_params.attention_resolutions
|
|
|
|
|
if resolution in unet_params["attention_resolutions"]
|
|
|
|
|
else "UpBlockSpatioTemporal"
|
|
|
|
|
)
|
|
|
|
|
up_block_types.append(block_type)
|
|
|
|
|
resolution //= 2
|
|
|
|
|
|
|
|
|
|
if unet_params.transformer_depth is not None:
|
|
|
|
|
if unet_params["transformer_depth"] is not None:
|
|
|
|
|
transformer_layers_per_block = (
|
|
|
|
|
unet_params.transformer_depth
|
|
|
|
|
if isinstance(unet_params.transformer_depth, int)
|
|
|
|
|
else list(unet_params.transformer_depth)
|
|
|
|
|
unet_params["transformer_depth"]
|
|
|
|
|
if isinstance(unet_params["transformer_depth"], int)
|
|
|
|
|
else list(unet_params["transformer_depth"])
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
transformer_layers_per_block = 1
|
|
|
|
|
|
|
|
|
|
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
|
|
|
|
|
|
|
|
|
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
|
|
|
|
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
|
|
|
|
|
use_linear_projection = (
|
|
|
|
|
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
|
|
|
|
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
|
|
|
|
|
)
|
|
|
|
|
if use_linear_projection:
|
|
|
|
|
# stable diffusion 2-base-512 and 2-768
|
|
|
|
|
if head_dim is None:
|
|
|
|
|
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
|
|
|
|
|
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
|
|
|
|
|
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
|
|
|
|
|
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
|
|
|
|
|
|
|
|
|
|
class_embed_type = None
|
|
|
|
|
addition_embed_type = None
|
|
|
|
@@ -72,23 +75,25 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
|
|
|
|
projection_class_embeddings_input_dim = None
|
|
|
|
|
context_dim = None
|
|
|
|
|
|
|
|
|
|
if unet_params.context_dim is not None:
|
|
|
|
|
if unet_params["context_dim"] is not None:
|
|
|
|
|
context_dim = (
|
|
|
|
|
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
|
|
|
|
|
unet_params["context_dim"]
|
|
|
|
|
if isinstance(unet_params["context_dim"], int)
|
|
|
|
|
else unet_params["context_dim"][0]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if "num_classes" in unet_params:
|
|
|
|
|
if unet_params.num_classes == "sequential":
|
|
|
|
|
if unet_params["num_classes"] == "sequential":
|
|
|
|
|
addition_time_embed_dim = 256
|
|
|
|
|
assert "adm_in_channels" in unet_params
|
|
|
|
|
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
|
|
|
|
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
|
|
|
|
|
|
|
|
|
|
config = {
|
|
|
|
|
"sample_size": image_size // vae_scale_factor,
|
|
|
|
|
"in_channels": unet_params.in_channels,
|
|
|
|
|
"in_channels": unet_params["in_channels"],
|
|
|
|
|
"down_block_types": tuple(down_block_types),
|
|
|
|
|
"block_out_channels": tuple(block_out_channels),
|
|
|
|
|
"layers_per_block": unet_params.num_res_blocks,
|
|
|
|
|
"layers_per_block": unet_params["num_res_blocks"],
|
|
|
|
|
"cross_attention_dim": context_dim,
|
|
|
|
|
"attention_head_dim": head_dim,
|
|
|
|
|
"use_linear_projection": use_linear_projection,
|
|
|
|
@@ -100,15 +105,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if "disable_self_attentions" in unet_params:
|
|
|
|
|
config["only_cross_attention"] = unet_params.disable_self_attentions
|
|
|
|
|
config["only_cross_attention"] = unet_params["disable_self_attentions"]
|
|
|
|
|
|
|
|
|
|
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
|
|
|
|
|
config["num_class_embeds"] = unet_params.num_classes
|
|
|
|
|
if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
|
|
|
|
|
config["num_class_embeds"] = unet_params["num_classes"]
|
|
|
|
|
|
|
|
|
|
if controlnet:
|
|
|
|
|
config["conditioning_channels"] = unet_params.hint_channels
|
|
|
|
|
config["conditioning_channels"] = unet_params["hint_channels"]
|
|
|
|
|
else:
|
|
|
|
|
config["out_channels"] = unet_params.out_channels
|
|
|
|
|
config["out_channels"] = unet_params["out_channels"]
|
|
|
|
|
config["up_block_types"] = tuple(up_block_types)
|
|
|
|
|
|
|
|
|
|
return config
|
|
|
|
|