Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5fa46f4b01 | |||
| a10cd21b04 |
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -970,3 +971,105 @@ class StableDiffusionXLAdapterPipeline(
|
|||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
return StableDiffusionXLPipelineOutput(images=image)
|
return StableDiffusionXLPipelineOutput(images=image)
|
||||||
|
|
||||||
|
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||||
|
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||||
|
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||||
|
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||||
|
# pipeline.
|
||||||
|
|
||||||
|
# Remove any existing hooks.
|
||||||
|
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||||
|
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||||
|
else:
|
||||||
|
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
|
||||||
|
|
||||||
|
is_model_cpu_offload = False
|
||||||
|
is_sequential_cpu_offload = False
|
||||||
|
recursive = False
|
||||||
|
for _, component in self.components.items():
|
||||||
|
if isinstance(component, torch.nn.Module):
|
||||||
|
if hasattr(component, "_hf_hook"):
|
||||||
|
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||||
|
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
||||||
|
logger.info(
|
||||||
|
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||||
|
)
|
||||||
|
recursive = is_sequential_cpu_offload
|
||||||
|
remove_hook_from_module(component, recurse=recursive)
|
||||||
|
state_dict, network_alphas = self.lora_state_dict(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
unet_config=self.unet.config,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
||||||
|
|
||||||
|
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||||
|
if len(text_encoder_state_dict) > 0:
|
||||||
|
self.load_lora_into_text_encoder(
|
||||||
|
text_encoder_state_dict,
|
||||||
|
network_alphas=network_alphas,
|
||||||
|
text_encoder=self.text_encoder,
|
||||||
|
prefix="text_encoder",
|
||||||
|
lora_scale=self.lora_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||||
|
if len(text_encoder_2_state_dict) > 0:
|
||||||
|
self.load_lora_into_text_encoder(
|
||||||
|
text_encoder_2_state_dict,
|
||||||
|
network_alphas=network_alphas,
|
||||||
|
text_encoder=self.text_encoder_2,
|
||||||
|
prefix="text_encoder_2",
|
||||||
|
lora_scale=self.lora_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Offload back.
|
||||||
|
if is_model_cpu_offload:
|
||||||
|
self.enable_model_cpu_offload()
|
||||||
|
elif is_sequential_cpu_offload:
|
||||||
|
self.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_lora_weights(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
|
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
|
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
|
is_main_process: bool = True,
|
||||||
|
weight_name: str = None,
|
||||||
|
save_function: Callable = None,
|
||||||
|
safe_serialization: bool = True,
|
||||||
|
):
|
||||||
|
state_dict = {}
|
||||||
|
|
||||||
|
def pack_weights(layers, prefix):
|
||||||
|
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||||
|
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||||
|
return layers_state_dict
|
||||||
|
|
||||||
|
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||||
|
raise ValueError(
|
||||||
|
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if unet_lora_layers:
|
||||||
|
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||||
|
|
||||||
|
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
||||||
|
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||||
|
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||||
|
|
||||||
|
self.write_lora_layers(
|
||||||
|
state_dict=state_dict,
|
||||||
|
save_directory=save_directory,
|
||||||
|
is_main_process=is_main_process,
|
||||||
|
weight_name=weight_name,
|
||||||
|
save_function=save_function,
|
||||||
|
safe_serialization=safe_serialization,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _remove_text_encoder_monkey_patch(self):
|
||||||
|
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||||
|
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||||
|
|||||||
Reference in New Issue
Block a user