Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d060d4118f | |||
| 919b726cf4 | |||
| fab7df8498 |
@@ -83,9 +83,9 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
|
|||||||
|
|
||||||
accelerate_version = "not installed"
|
accelerate_version = "not installed"
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
import accelerate
|
from accelerate import __version__
|
||||||
|
|
||||||
accelerate_version = accelerate.__version__
|
accelerate_version = __version__
|
||||||
|
|
||||||
peft_version = "not installed"
|
peft_version = "not installed"
|
||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
import accelerate
|
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||||
@@ -736,7 +736,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
|
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
# Instantiate model with empty weights
|
# Instantiate model with empty weights
|
||||||
with accelerate.init_empty_weights():
|
with init_empty_weights():
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||||
@@ -781,7 +781,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
device_map = {"": "cpu"}
|
device_map = {"": "cpu"}
|
||||||
force_hook = False
|
force_hook = False
|
||||||
try:
|
try:
|
||||||
accelerate.load_checkpoint_and_dispatch(
|
load_checkpoint_and_dispatch(
|
||||||
model,
|
model,
|
||||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||||
device_map,
|
device_map,
|
||||||
@@ -811,7 +811,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
" please also re-upload it or open a PR on the original repository."
|
" please also re-upload it or open a PR on the original repository."
|
||||||
)
|
)
|
||||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||||
accelerate.load_checkpoint_and_dispatch(
|
load_checkpoint_and_dispatch(
|
||||||
model,
|
model,
|
||||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||||
device_map,
|
device_map,
|
||||||
|
|||||||
@@ -50,8 +50,7 @@ if is_transformers_available():
|
|||||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
import accelerate
|
from accelerate import dispatch_model, init_empty_weights
|
||||||
from accelerate import dispatch_model
|
|
||||||
from accelerate.hooks import remove_hook_from_module
|
from accelerate.hooks import remove_hook_from_module
|
||||||
from accelerate.utils import compute_module_sizes, get_max_memory
|
from accelerate.utils import compute_module_sizes, get_max_memory
|
||||||
|
|
||||||
@@ -443,7 +442,7 @@ def _load_empty_model(
|
|||||||
subfolder=kwargs.pop("subfolder", None),
|
subfolder=kwargs.pop("subfolder", None),
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
with accelerate.init_empty_weights():
|
with init_empty_weights():
|
||||||
model = class_obj.from_config(config, **unused_kwargs)
|
model = class_obj.from_config(config, **unused_kwargs)
|
||||||
elif is_transformers_model:
|
elif is_transformers_model:
|
||||||
config_class = getattr(class_obj, "config_class", None)
|
config_class = getattr(class_obj, "config_class", None)
|
||||||
@@ -461,7 +460,7 @@ def _load_empty_model(
|
|||||||
revision=kwargs.pop("revision", None),
|
revision=kwargs.pop("revision", None),
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
with accelerate.init_empty_weights():
|
with init_empty_weights():
|
||||||
model = class_obj(config)
|
model = class_obj(config)
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
@@ -529,7 +528,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
|
|||||||
name,
|
name,
|
||||||
is_pipeline_module,
|
is_pipeline_module,
|
||||||
)
|
)
|
||||||
with accelerate.init_empty_weights():
|
with init_empty_weights():
|
||||||
loaded_sub_model = passed_class_obj[name]
|
loaded_sub_model = passed_class_obj[name]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ from .pipeline_loading_utils import (
|
|||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
import accelerate
|
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||||
|
|
||||||
|
|
||||||
LIBRARIES = []
|
LIBRARIES = []
|
||||||
@@ -377,16 +377,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return hasattr(module, "_hf_hook") and (
|
return hasattr(module, "_hf_hook") and (
|
||||||
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
|
isinstance(module._hf_hook, AlignDevicesHook)
|
||||||
or hasattr(module._hf_hook, "hooks")
|
or hasattr(module._hf_hook, "hooks")
|
||||||
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
|
and isinstance(module._hf_hook.hooks[0], AlignDevicesHook)
|
||||||
)
|
)
|
||||||
|
|
||||||
def module_is_offloaded(module):
|
def module_is_offloaded(module):
|
||||||
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, CpuOffload)
|
||||||
|
|
||||||
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
|
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
|
||||||
pipeline_is_sequentially_offloaded = any(
|
pipeline_is_sequentially_offloaded = any(
|
||||||
@@ -1009,7 +1009,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
"""
|
"""
|
||||||
for _, model in self.components.items():
|
for _, model in self.components.items():
|
||||||
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
|
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
|
||||||
accelerate.hooks.remove_hook_from_module(model, recurse=True)
|
remove_hook_from_module(model, recurse=True)
|
||||||
self._all_hooks = []
|
self._all_hooks = []
|
||||||
|
|
||||||
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from .import_utils import is_accelerate_available
|
|||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
import accelerate
|
from accelerate import __version__
|
||||||
|
|
||||||
|
|
||||||
def apply_forward_hook(method):
|
def apply_forward_hook(method):
|
||||||
@@ -36,7 +36,7 @@ def apply_forward_hook(method):
|
|||||||
"""
|
"""
|
||||||
if not is_accelerate_available():
|
if not is_accelerate_available():
|
||||||
return method
|
return method
|
||||||
accelerate_version = version.parse(accelerate.__version__).base_version
|
accelerate_version = version.parse(__version__).base_version
|
||||||
if version.parse(accelerate_version) < version.parse("0.17.0"):
|
if version.parse(accelerate_version) < version.parse("0.17.0"):
|
||||||
return method
|
return method
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user