|
|
|
@@ -23,9 +23,7 @@ from torch.nn.modules.normalization import GroupNorm
|
|
|
|
|
|
|
|
|
|
from ..configuration_utils import ConfigMixin, register_to_config
|
|
|
|
|
from ..utils import BaseOutput, logging
|
|
|
|
|
from .attention_processor import (
|
|
|
|
|
AttentionProcessor,
|
|
|
|
|
)
|
|
|
|
|
from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
|
|
|
|
|
from .autoencoders import AutoencoderKL
|
|
|
|
|
from .lora import LoRACompatibleConv
|
|
|
|
|
from .modeling_utils import ModelMixin
|
|
|
|
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
|
|
|
|
|
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
|
|
|
|
|
norm_kwargs["num_channels"] += by # surgery done here
|
|
|
|
|
# conv1
|
|
|
|
|
conv1_args = (
|
|
|
|
|
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
|
|
|
|
|
)
|
|
|
|
|
conv1_args = [
|
|
|
|
|
"in_channels",
|
|
|
|
|
"out_channels",
|
|
|
|
|
"kernel_size",
|
|
|
|
|
"stride",
|
|
|
|
|
"padding",
|
|
|
|
|
"dilation",
|
|
|
|
|
"groups",
|
|
|
|
|
"bias",
|
|
|
|
|
"padding_mode",
|
|
|
|
|
]
|
|
|
|
|
if not USE_PEFT_BACKEND:
|
|
|
|
|
conv1_args.append("lora_layer")
|
|
|
|
|
|
|
|
|
|
for a in conv1_args:
|
|
|
|
|
assert hasattr(old_conv1, a)
|
|
|
|
|
|
|
|
|
|
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
|
|
|
|
|
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
|
|
|
|
|
conv1_kwargs["in_channels"] += by # surgery done here
|
|
|
|
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
|
|
|
|
|
}
|
|
|
|
|
# swap old with new modules
|
|
|
|
|
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
|
|
|
|
|
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
|
|
|
|
|
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
|
|
|
|
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
|
|
|
|
|
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
|
|
|
|
|
)
|
|
|
|
|
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
|
|
|
|
|
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
|
|
|
|
)
|
|
|
|
|
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
|
|
|
|
|
"""Increase channels sizes to allow for additional concatted information from base model"""
|
|
|
|
|
old_down = unet.down_blocks[block_no].downsamplers[0].conv
|
|
|
|
|
# conv1
|
|
|
|
|
args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
|
|
|
|
|
" "
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
args = [
|
|
|
|
|
"in_channels",
|
|
|
|
|
"out_channels",
|
|
|
|
|
"kernel_size",
|
|
|
|
|
"stride",
|
|
|
|
|
"padding",
|
|
|
|
|
"dilation",
|
|
|
|
|
"groups",
|
|
|
|
|
"bias",
|
|
|
|
|
"padding_mode",
|
|
|
|
|
]
|
|
|
|
|
if not USE_PEFT_BACKEND:
|
|
|
|
|
args.append("lora_layer")
|
|
|
|
|
|
|
|
|
|
for a in args:
|
|
|
|
|
assert hasattr(old_down, a)
|
|
|
|
|
kwargs = {a: getattr(old_down, a) for a in args}
|
|
|
|
|
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
|
|
|
|
|
kwargs["in_channels"] += by # surgery done here
|
|
|
|
|
# swap old with new modules
|
|
|
|
|
unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
|
|
|
|
|
unet.down_blocks[block_no].downsamplers[0].conv = (
|
|
|
|
|
nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
|
|
|
|
|
)
|
|
|
|
|
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
|
|
|
|
|
assert hasattr(old_norm1, a)
|
|
|
|
|
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
|
|
|
|
|
norm_kwargs["num_channels"] += by # surgery done here
|
|
|
|
|
# conv1
|
|
|
|
|
conv1_args = (
|
|
|
|
|
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
|
|
|
|
|
)
|
|
|
|
|
for a in conv1_args:
|
|
|
|
|
assert hasattr(old_conv1, a)
|
|
|
|
|
conv1_args = [
|
|
|
|
|
"in_channels",
|
|
|
|
|
"out_channels",
|
|
|
|
|
"kernel_size",
|
|
|
|
|
"stride",
|
|
|
|
|
"padding",
|
|
|
|
|
"dilation",
|
|
|
|
|
"groups",
|
|
|
|
|
"bias",
|
|
|
|
|
"padding_mode",
|
|
|
|
|
]
|
|
|
|
|
if not USE_PEFT_BACKEND:
|
|
|
|
|
conv1_args.append("lora_layer")
|
|
|
|
|
|
|
|
|
|
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
|
|
|
|
|
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
|
|
|
|
|
conv1_kwargs["in_channels"] += by # surgery done here
|
|
|
|
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
|
|
|
|
|
}
|
|
|
|
|
# swap old with new modules
|
|
|
|
|
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
|
|
|
|
|
unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
|
|
|
|
|
unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
|
|
|
|
unet.mid_block.resnets[0].conv1 = (
|
|
|
|
|
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
|
|
|
|
|
)
|
|
|
|
|
unet.mid_block.resnets[0].conv_shortcut = (
|
|
|
|
|
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
|
|
|
|
)
|
|
|
|
|
unet.mid_block.resnets[0].in_channels += by # surgery done here
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|