Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d9915a7d65 | |||
| b7a795dbeb | |||
| 438905d63e | |||
| 904f24de5a | |||
| e123bbcbc4 | |||
| b3fa8c695d | |||
| 720be2bac5 | |||
| e74b782aac | |||
| d6392b4b49 | |||
| 1475026960 | |||
| 878eb4ce35 |
@@ -29,11 +29,16 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
||||
|
||||
# Always use memory-efficient CPU offloading to minimize RAM usage
|
||||
|
||||
_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
||||
@@ -56,7 +61,6 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -68,12 +72,8 @@ class ModuleGroup:
|
||||
self.buffers = buffers
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
@@ -82,23 +82,125 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
# Use the most efficient module-level transfer when possible
|
||||
# This approach mirrors how PyTorch handles full model transfers
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Only onload if some parameters are not on the target device
|
||||
if any(p.device != self.onload_device for p in group_module.parameters()):
|
||||
try:
|
||||
# Try the most efficient approach using _apply
|
||||
if hasattr(group_module, "_apply"):
|
||||
# This is what module.to() uses internally
|
||||
def to_device(t):
|
||||
if t.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
return t.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
return t.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
return t
|
||||
|
||||
# Apply to all tensors without unnecessary copies
|
||||
group_module._apply(to_device)
|
||||
else:
|
||||
# Fallback to direct parameter transfer
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
except Exception as e:
|
||||
# If optimization fails, fall back to direct parameter transfer
|
||||
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle explicit parameters
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle buffers
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if buffer.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
buffer.data = buffer.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
# For CPU offloading
|
||||
if self.offload_device.type == "cpu":
|
||||
# Synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# Empty GPU cache before offloading to reduce memory fragmentation
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# For module groups, use a single, unified approach that is closest to
|
||||
# the behavior of model.to("cpu")
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Check if we need to offload this module
|
||||
if any(p.device.type != "cpu" for p in group_module.parameters()):
|
||||
# Use PyTorch's built-in to() method directly, which preserves
|
||||
# memory mapping when moving to CPU
|
||||
try:
|
||||
# Non-blocking=False for CPU transfers, as it ensures memory is
|
||||
# immediately available and potentially preserves memory mapping
|
||||
group_module.to("cpu", non_blocking=False)
|
||||
except Exception as e:
|
||||
# If there's any error, fall back to parameter-level offloading
|
||||
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
|
||||
for param in group_module.parameters():
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle explicit parameters - move directly to CPU with non-blocking=False
|
||||
# which can preserve memory mapping in some PyTorch versions
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle buffers
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
if buffer.device.type != "cpu":
|
||||
buffer.data = buffer.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Let Python's normal reference counting handle cleanup
|
||||
# We don't force garbage collection to avoid slowing down inference
|
||||
|
||||
else:
|
||||
# For non-CPU offloading, synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# For non-CPU offloading, use the regular approach
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
@@ -108,6 +210,9 @@ class ModuleGroup:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
# After offloading, we can unpin the memory if configured to do so
|
||||
# We'll keep it pinned by default for better performance
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
# Offload to CPU
|
||||
self.group.offload_()
|
||||
return module
|
||||
|
||||
@@ -313,7 +419,8 @@ def apply_group_offloading(
|
||||
If True, offloading and onloading is done with non-blocking data transfer.
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer.
|
||||
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
|
||||
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -344,12 +451,19 @@ def apply_group_offloading(
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
# We no longer need a pinned group manager as we're not using pinned memory
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
module,
|
||||
num_blocks_per_group,
|
||||
offload_device,
|
||||
onload_device,
|
||||
non_blocking,
|
||||
stream,
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
# We no longer need a CPU parameter dictionary
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
# We no longer need a CPU parameter dictionary
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -25,6 +25,7 @@ from types import ModuleType
|
||||
from typing import Any, Union
|
||||
|
||||
from huggingface_hub.utils import is_jinja_available # noqa: F401
|
||||
from packaging import version
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from . import logging
|
||||
@@ -51,30 +52,36 @@ DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||
|
||||
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
|
||||
|
||||
|
||||
def _is_package_available(pkg_name: str):
|
||||
pkg_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
pkg_version = "N/A"
|
||||
|
||||
if pkg_exists:
|
||||
try:
|
||||
pkg_version = importlib_metadata.version(pkg_name)
|
||||
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
|
||||
except (ImportError, importlib_metadata.PackageNotFoundError):
|
||||
pkg_exists = False
|
||||
|
||||
return pkg_exists, pkg_version
|
||||
|
||||
|
||||
_torch_version = "N/A"
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available, _torch_version = _is_package_available("torch")
|
||||
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
if _torch_available:
|
||||
try:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
logger.info(f"PyTorch version {_torch_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TORCH is set")
|
||||
_torch_available = False
|
||||
|
||||
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
|
||||
if _torch_xla_available:
|
||||
try:
|
||||
_torch_xla_version = importlib_metadata.version("torch_xla")
|
||||
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
|
||||
except ImportError:
|
||||
_torch_xla_available = False
|
||||
|
||||
# check whether torch_npu is available
|
||||
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
|
||||
if _torch_npu_available:
|
||||
try:
|
||||
_torch_npu_version = importlib_metadata.version("torch_npu")
|
||||
logger.info(f"torch_npu version {_torch_npu_version} available.")
|
||||
except ImportError:
|
||||
_torch_npu_available = False
|
||||
|
||||
_jax_version = "N/A"
|
||||
_flax_version = "N/A"
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
@@ -90,12 +97,47 @@ else:
|
||||
_flax_available = False
|
||||
|
||||
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_safetensors_available, _safetensors_version = _is_package_available("safetensors")
|
||||
|
||||
_safetensors_available = importlib.util.find_spec("safetensors") is not None
|
||||
if _safetensors_available:
|
||||
try:
|
||||
_safetensors_version = importlib_metadata.version("safetensors")
|
||||
logger.info(f"Safetensors version {_safetensors_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_safetensors_available = False
|
||||
else:
|
||||
logger.info("Disabling Safetensors because USE_TF is set")
|
||||
_safetensors_available = False
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
|
||||
try:
|
||||
_hf_hub_version = importlib_metadata.version("huggingface_hub")
|
||||
logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_hf_hub_available = False
|
||||
|
||||
|
||||
_inflect_available = importlib.util.find_spec("inflect") is not None
|
||||
try:
|
||||
_inflect_version = importlib_metadata.version("inflect")
|
||||
logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_inflect_available = False
|
||||
|
||||
|
||||
_unidecode_available = importlib.util.find_spec("unidecode") is not None
|
||||
try:
|
||||
_unidecode_version = importlib_metadata.version("unidecode")
|
||||
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_unidecode_available = False
|
||||
|
||||
_onnxruntime_version = "N/A"
|
||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||
if _onnx_available:
|
||||
@@ -144,6 +186,85 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_opencv_available = False
|
||||
|
||||
_scipy_available = importlib.util.find_spec("scipy") is not None
|
||||
try:
|
||||
_scipy_version = importlib_metadata.version("scipy")
|
||||
logger.debug(f"Successfully imported scipy version {_scipy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scipy_available = False
|
||||
|
||||
_librosa_available = importlib.util.find_spec("librosa") is not None
|
||||
try:
|
||||
_librosa_version = importlib_metadata.version("librosa")
|
||||
logger.debug(f"Successfully imported librosa version {_librosa_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_librosa_available = False
|
||||
|
||||
_accelerate_available = importlib.util.find_spec("accelerate") is not None
|
||||
try:
|
||||
_accelerate_version = importlib_metadata.version("accelerate")
|
||||
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_accelerate_available = False
|
||||
|
||||
_xformers_available = importlib.util.find_spec("xformers") is not None
|
||||
try:
|
||||
_xformers_version = importlib_metadata.version("xformers")
|
||||
if _torch_available:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
if version.Version(_torch_version) < version.Version("1.12"):
|
||||
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
|
||||
|
||||
logger.debug(f"Successfully imported xformers version {_xformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_xformers_available = False
|
||||
|
||||
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
|
||||
try:
|
||||
_k_diffusion_version = importlib_metadata.version("k_diffusion")
|
||||
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_k_diffusion_available = False
|
||||
|
||||
_note_seq_available = importlib.util.find_spec("note_seq") is not None
|
||||
try:
|
||||
_note_seq_version = importlib_metadata.version("note_seq")
|
||||
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_note_seq_available = False
|
||||
|
||||
_wandb_available = importlib.util.find_spec("wandb") is not None
|
||||
try:
|
||||
_wandb_version = importlib_metadata.version("wandb")
|
||||
logger.debug(f"Successfully imported wandb version {_wandb_version }")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_wandb_available = False
|
||||
|
||||
|
||||
_tensorboard_available = importlib.util.find_spec("tensorboard")
|
||||
try:
|
||||
_tensorboard_version = importlib_metadata.version("tensorboard")
|
||||
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_tensorboard_available = False
|
||||
|
||||
|
||||
_compel_available = importlib.util.find_spec("compel")
|
||||
try:
|
||||
_compel_version = importlib_metadata.version("compel")
|
||||
logger.debug(f"Successfully imported compel version {_compel_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_compel_available = False
|
||||
|
||||
|
||||
_ftfy_available = importlib.util.find_spec("ftfy") is not None
|
||||
try:
|
||||
_ftfy_version = importlib_metadata.version("ftfy")
|
||||
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_ftfy_available = False
|
||||
|
||||
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
try:
|
||||
# importlib metadata under different name
|
||||
@@ -152,6 +273,13 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_bs4_available = False
|
||||
|
||||
_torchsde_available = importlib.util.find_spec("torchsde") is not None
|
||||
try:
|
||||
_torchsde_version = importlib_metadata.version("torchsde")
|
||||
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torchsde_available = False
|
||||
|
||||
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
|
||||
try:
|
||||
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
|
||||
@@ -159,42 +287,91 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_invisible_watermark_available = False
|
||||
|
||||
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
|
||||
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
|
||||
_transformers_available, _transformers_version = _is_package_available("transformers")
|
||||
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
|
||||
_inflect_available, _inflect_version = _is_package_available("inflect")
|
||||
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
|
||||
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
|
||||
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
|
||||
_wandb_available, _wandb_version = _is_package_available("wandb")
|
||||
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
|
||||
_compel_available, _compel_version = _is_package_available("compel")
|
||||
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
|
||||
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
|
||||
_peft_available, _peft_version = _is_package_available("peft")
|
||||
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
|
||||
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
|
||||
_timm_available, _timm_version = _is_package_available("timm")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_imageio_available, _imageio_version = _is_package_available("imageio")
|
||||
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
|
||||
_scipy_available, _scipy_version = _is_package_available("scipy")
|
||||
_librosa_available, _librosa_version = _is_package_available("librosa")
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
|
||||
_xformers_available, _xformers_version = _is_package_available("xformers")
|
||||
_gguf_available, _gguf_version = _is_package_available("gguf")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
|
||||
_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
|
||||
if _optimum_quanto_available:
|
||||
_peft_available = importlib.util.find_spec("peft") is not None
|
||||
try:
|
||||
_peft_version = importlib_metadata.version("peft")
|
||||
logger.debug(f"Successfully imported peft version {_peft_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_peft_available = False
|
||||
|
||||
_torchvision_available = importlib.util.find_spec("torchvision") is not None
|
||||
try:
|
||||
_torchvision_version = importlib_metadata.version("torchvision")
|
||||
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torchvision_available = False
|
||||
|
||||
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
|
||||
try:
|
||||
_sentencepiece_version = importlib_metadata.version("sentencepiece")
|
||||
logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_sentencepiece_available = False
|
||||
|
||||
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
|
||||
try:
|
||||
_matplotlib_version = importlib_metadata.version("matplotlib")
|
||||
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_matplotlib_available = False
|
||||
|
||||
_timm_available = importlib.util.find_spec("timm") is not None
|
||||
if _timm_available:
|
||||
try:
|
||||
_timm_version = importlib_metadata.version("timm")
|
||||
logger.info(f"Timm version {_timm_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_timm_available = False
|
||||
|
||||
|
||||
def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
|
||||
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
|
||||
try:
|
||||
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
|
||||
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_bitsandbytes_available = False
|
||||
|
||||
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
|
||||
|
||||
_imageio_available = importlib.util.find_spec("imageio") is not None
|
||||
if _imageio_available:
|
||||
try:
|
||||
_imageio_version = importlib_metadata.version("imageio")
|
||||
logger.debug(f"Successfully imported imageio version {_imageio_version}")
|
||||
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_imageio_available = False
|
||||
|
||||
_is_gguf_available = importlib.util.find_spec("gguf") is not None
|
||||
if _is_gguf_available:
|
||||
try:
|
||||
_gguf_version = importlib_metadata.version("gguf")
|
||||
logger.debug(f"Successfully import gguf version {_gguf_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_gguf_available = False
|
||||
|
||||
|
||||
_is_torchao_available = importlib.util.find_spec("torchao") is not None
|
||||
if _is_torchao_available:
|
||||
try:
|
||||
_torchao_version = importlib_metadata.version("torchao")
|
||||
logger.debug(f"Successfully import torchao version {_torchao_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_torchao_available = False
|
||||
|
||||
|
||||
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
|
||||
if _is_optimum_quanto_available:
|
||||
try:
|
||||
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
|
||||
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_optimum_quanto_available = False
|
||||
_is_optimum_quanto_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -318,19 +495,15 @@ def is_imageio_available():
|
||||
|
||||
|
||||
def is_gguf_available():
|
||||
return _gguf_available
|
||||
return _is_gguf_available
|
||||
|
||||
|
||||
def is_torchao_available():
|
||||
return _torchao_available
|
||||
return _is_torchao_available
|
||||
|
||||
|
||||
def is_optimum_quanto_available():
|
||||
return _optimum_quanto_available
|
||||
|
||||
|
||||
def is_timm_available():
|
||||
return _timm_available
|
||||
return _is_optimum_quanto_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
@@ -690,7 +863,7 @@ def is_gguf_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _gguf_available:
|
||||
if not _is_gguf_available:
|
||||
return False
|
||||
return compare_versions(parse(_gguf_version), operation, version)
|
||||
|
||||
@@ -705,7 +878,7 @@ def is_torchao_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _torchao_available:
|
||||
if not _is_torchao_available:
|
||||
return False
|
||||
return compare_versions(parse(_torchao_version), operation, version)
|
||||
|
||||
@@ -735,7 +908,7 @@ def is_optimum_quanto_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _optimum_quanto_available:
|
||||
if not _is_optimum_quanto_available:
|
||||
return False
|
||||
return compare_versions(parse(_optimum_quanto_version), operation, version)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user