Compare commits

..

11 Commits

Author SHA1 Message Date
DN6 d9915a7d65 update 2025-03-12 11:44:40 +05:30
DN6 b7a795dbeb update 2025-03-12 11:40:40 +05:30
DN6 438905d63e update 2025-03-12 11:37:27 +05:30
DN6 904f24de5a update 2025-03-12 11:35:18 +05:30
DN6 e123bbcbc4 memmap 2025-03-12 11:23:14 +05:30
DN6 b3fa8c695d remove cpu param dict 2025-03-12 09:02:04 +05:30
DN6 720be2bac5 update 2025-03-12 08:49:45 +05:30
DN6 e74b782aac update 2025-03-12 08:45:09 +05:30
DN6 d6392b4b49 update 2025-03-12 08:18:19 +05:30
DN6 1475026960 sliding-window 2025-03-11 13:56:39 +05:30
DN6 878eb4ce35 update 2025-03-11 13:21:09 +05:30
2 changed files with 367 additions and 95 deletions
+132 -33
View File
@@ -29,11 +29,16 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
# Always use memory-efficient CPU offloading to minimize RAM usage
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -68,12 +72,8 @@ class ModuleGroup:
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
@@ -82,23 +82,125 @@ class ModuleGroup:
self.stream.synchronize()
with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
# Use the most efficient module-level transfer when possible
# This approach mirrors how PyTorch handles full model transfers
if self.modules:
for group_module in self.modules:
# Only onload if some parameters are not on the target device
if any(p.device != self.onload_device for p in group_module.parameters()):
try:
# Try the most efficient approach using _apply
if hasattr(group_module, "_apply"):
# This is what module.to() uses internally
def to_device(t):
if t.device != self.onload_device:
if self.onload_device.type == "cuda":
return t.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
return t.to(self.onload_device,
non_blocking=self.non_blocking)
return t
# Apply to all tensors without unnecessary copies
group_module._apply(to_device)
else:
# Fallback to direct parameter transfer
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
except Exception as e:
# If optimization fails, fall back to direct parameter transfer
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle explicit parameters
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if buffer.device != self.onload_device:
if self.onload_device.type == "cuda":
buffer.data = buffer.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
buffer.data = buffer.data.to(self.onload_device,
non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
# For CPU offloading
if self.offload_device.type == "cpu":
# Synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# Empty GPU cache before offloading to reduce memory fragmentation
if torch.cuda.is_available():
torch.cuda.empty_cache()
# For module groups, use a single, unified approach that is closest to
# the behavior of model.to("cpu")
if self.modules:
for group_module in self.modules:
# Check if we need to offload this module
if any(p.device.type != "cpu" for p in group_module.parameters()):
# Use PyTorch's built-in to() method directly, which preserves
# memory mapping when moving to CPU
try:
# Non-blocking=False for CPU transfers, as it ensures memory is
# immediately available and potentially preserves memory mapping
group_module.to("cpu", non_blocking=False)
except Exception as e:
# If there's any error, fall back to parameter-level offloading
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
for param in group_module.parameters():
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle explicit parameters - move directly to CPU with non-blocking=False
# which can preserve memory mapping in some PyTorch versions
if self.parameters is not None:
for param in self.parameters:
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
if buffer.device.type != "cpu":
buffer.data = buffer.data.to("cpu", non_blocking=False)
# Let Python's normal reference counting handle cleanup
# We don't force garbage collection to avoid slowing down inference
else:
# For non-CPU offloading, synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# For non-CPU offloading, use the regular approach
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
@@ -108,6 +210,9 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
# NOTE: this module no longer pins CPU memory, so nothing needs to be unpinned after offloading.
class GroupOffloadingHook(ModelHook):
r"""
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
# Offload to CPU
self.group.offload_()
return module
@@ -313,7 +419,8 @@ def apply_group_offloading(
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
Example:
```python
@@ -344,12 +451,19 @@ def apply_group_offloading(
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
# We no longer need a pinned group manager as we're not using pinned memory
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module,
num_blocks_per_group,
offload_device,
onload_device,
non_blocking,
stream,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
+235 -62
View File
@@ -25,6 +25,7 @@ from types import ModuleType
from typing import Any, Union
from huggingface_hub.utils import is_jinja_available # noqa: F401
from packaging import version
from packaging.version import Version, parse
from . import logging
@@ -51,30 +52,36 @@ DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
def _is_package_available(pkg_name: str):
    """Report whether `pkg_name` can be imported and which version is installed.

    Args:
        pkg_name: Name passed both to `importlib.util.find_spec` (as the import name) and to
            `importlib_metadata.version` (as the distribution name).

    Returns:
        A `(available, version)` tuple. `version` is `"N/A"` whenever `available` is `False`.
    """
    try:
        pkg_exists = importlib.util.find_spec(pkg_name) is not None
    except (ImportError, ValueError):
        # `find_spec` raises instead of returning `None` for a dotted name whose parent package
        # is missing (ModuleNotFoundError) and for a module already in `sys.modules` whose
        # `__spec__` is `None` (ValueError). Both cases mean "not usable", not "crash".
        pkg_exists = False
    pkg_version = "N/A"

    if pkg_exists:
        try:
            pkg_version = importlib_metadata.version(pkg_name)
            logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
        except (ImportError, importlib_metadata.PackageNotFoundError):
            # An importable module without distribution metadata is treated as not installed.
            pkg_exists = False

    return pkg_exists, pkg_version
_torch_version = "N/A"
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available, _torch_version = _is_package_available("torch")
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else:
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
try:
_torch_xla_version = importlib_metadata.version("torch_xla")
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
except ImportError:
_torch_xla_available = False
# check whether torch_npu is available
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
try:
_torch_npu_version = importlib_metadata.version("torch_npu")
logger.info(f"torch_npu version {_torch_npu_version} available.")
except ImportError:
_torch_npu_available = False
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -90,12 +97,47 @@ else:
_flax_available = False
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    # `_is_package_available` already does the `find_spec` + metadata lookup and falls back to
    # `(False, "N/A")`, so one probe binds both names; the hand-rolled second probe was redundant.
    _safetensors_available, _safetensors_version = _is_package_available("safetensors")
    if _safetensors_available:
        logger.info(f"Safetensors version {_safetensors_version} available.")
else:
    # The gate above is USE_SAFETENSORS; the previous message wrongly blamed USE_TF.
    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
    _safetensors_available = False
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
_transformers_version = importlib_metadata.version("transformers")
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
_transformers_available = False
_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
try:
_hf_hub_version = importlib_metadata.version("huggingface_hub")
logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
except importlib_metadata.PackageNotFoundError:
_hf_hub_available = False
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
_inflect_version = importlib_metadata.version("inflect")
logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
_inflect_available = False
_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
_unidecode_version = importlib_metadata.version("unidecode")
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
_unidecode_available = False
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
@@ -144,6 +186,85 @@ try:
except importlib_metadata.PackageNotFoundError:
_opencv_available = False
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
_scipy_version = importlib_metadata.version("scipy")
logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
_scipy_available = False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
_librosa_version = importlib_metadata.version("librosa")
logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
_accelerate_version = importlib_metadata.version("accelerate")
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
_torch_version = importlib_metadata.version("torch")
if version.Version(_torch_version) < version.Version("1.12"):
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
_k_diffusion_version = importlib_metadata.version("k_diffusion")
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
except importlib_metadata.PackageNotFoundError:
_k_diffusion_available = False
_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
_note_seq_version = importlib_metadata.version("note_seq")
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
except importlib_metadata.PackageNotFoundError:
_note_seq_available = False
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
_wandb_version = importlib_metadata.version("wandb")
logger.debug(f"Successfully imported wandb version {_wandb_version }")
except importlib_metadata.PackageNotFoundError:
_wandb_available = False
# `find_spec` returns a ModuleSpec or None; compare against None so the flag is a real bool,
# matching the other `_*_available` flags in this module.
_tensorboard_available = importlib.util.find_spec("tensorboard") is not None
try:
    _tensorboard_version = importlib_metadata.version("tensorboard")
    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
    # No distribution metadata: treat the package as not installed.
    _tensorboard_available = False
# `find_spec` returns a ModuleSpec or None; compare against None so the flag is a real bool,
# matching the other `_*_available` flags in this module.
_compel_available = importlib.util.find_spec("compel") is not None
try:
    _compel_version = importlib_metadata.version("compel")
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    # No distribution metadata: treat the package as not installed.
    _compel_available = False
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
_ftfy_version = importlib_metadata.version("ftfy")
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
_ftfy_available = False
_bs4_available = importlib.util.find_spec("bs4") is not None
try:
# importlib metadata under different name
@@ -152,6 +273,13 @@ try:
except importlib_metadata.PackageNotFoundError:
_bs4_available = False
_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
_torchsde_version = importlib_metadata.version("torchsde")
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
_torchsde_available = False
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
@@ -159,42 +287,91 @@ try:
except importlib_metadata.PackageNotFoundError:
_invisible_watermark_available = False
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
_wandb_available, _wandb_version = _is_package_available("wandb")
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
_compel_available, _compel_version = _is_package_available("compel")
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
_peft_available, _peft_version = _is_package_available("peft")
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
_timm_available, _timm_version = _is_package_available("timm")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_imageio_available, _imageio_version = _is_package_available("imageio")
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
_scipy_available, _scipy_version = _is_package_available("scipy")
_librosa_available, _librosa_version = _is_package_available("librosa")
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
_xformers_available, _xformers_version = _is_package_available("xformers")
_gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_torchao_available, _torchao_version = _is_package_available("torchao")
_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
if _optimum_quanto_available:
_peft_available = importlib.util.find_spec("peft") is not None
try:
_peft_version = importlib_metadata.version("peft")
logger.debug(f"Successfully imported peft version {_peft_version}")
except importlib_metadata.PackageNotFoundError:
_peft_available = False
_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
_torchvision_version = importlib_metadata.version("torchvision")
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
except importlib_metadata.PackageNotFoundError:
_torchvision_available = False
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
try:
_sentencepiece_version = importlib_metadata.version("sentencepiece")
logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
except importlib_metadata.PackageNotFoundError:
_sentencepiece_available = False
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
try:
_matplotlib_version = importlib_metadata.version("matplotlib")
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
except importlib_metadata.PackageNotFoundError:
_matplotlib_available = False
_timm_available = importlib.util.find_spec("timm") is not None
if _timm_available:
try:
_timm_version = importlib_metadata.version("timm")
logger.info(f"Timm version {_timm_version} available.")
except importlib_metadata.PackageNotFoundError:
_timm_available = False
def is_timm_available():
    """Return the module-level `_timm_available` flag computed when this module is imported."""
    return _timm_available
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
try:
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
except importlib_metadata.PackageNotFoundError:
_bitsandbytes_available = False
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
_imageio_available = importlib.util.find_spec("imageio") is not None
if _imageio_available:
try:
_imageio_version = importlib_metadata.version("imageio")
logger.debug(f"Successfully imported imageio version {_imageio_version}")
except importlib_metadata.PackageNotFoundError:
_imageio_available = False
_is_gguf_available = importlib.util.find_spec("gguf") is not None
if _is_gguf_available:
try:
_gguf_version = importlib_metadata.version("gguf")
logger.debug(f"Successfully import gguf version {_gguf_version}")
except importlib_metadata.PackageNotFoundError:
_is_gguf_available = False
_is_torchao_available = importlib.util.find_spec("torchao") is not None
if _is_torchao_available:
try:
_torchao_version = importlib_metadata.version("torchao")
logger.debug(f"Successfully import torchao version {_torchao_version}")
except importlib_metadata.PackageNotFoundError:
_is_torchao_available = False
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
if _is_optimum_quanto_available:
try:
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
except importlib_metadata.PackageNotFoundError:
_optimum_quanto_available = False
_is_optimum_quanto_available = False
def is_torch_available():
@@ -318,19 +495,15 @@ def is_imageio_available():
def is_gguf_available():
return _gguf_available
return _is_gguf_available
def is_torchao_available():
return _torchao_available
return _is_torchao_available
def is_optimum_quanto_available():
return _optimum_quanto_available
def is_timm_available():
return _timm_available
return _is_optimum_quanto_available
# docstyle-ignore
@@ -690,7 +863,7 @@ def is_gguf_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _gguf_available:
if not _is_gguf_available:
return False
return compare_versions(parse(_gguf_version), operation, version)
@@ -705,7 +878,7 @@ def is_torchao_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _torchao_available:
if not _is_torchao_available:
return False
return compare_versions(parse(_torchao_version), operation, version)
@@ -735,7 +908,7 @@ def is_optimum_quanto_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _optimum_quanto_available:
if not _is_optimum_quanto_available:
return False
return compare_versions(parse(_optimum_quanto_version), operation, version)