Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul f9ef564a73 Merge branch 'main' into auto-offload-improv 2025-11-26 15:23:01 +05:30
sayakpaul 7dad173147 error early in auto_cpu_offload 2025-11-03 11:35:20 +05:30
4 changed files with 22 additions and 16 deletions
+1 -1
View File
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] apply_faster_cache
## FirstBlockCacheConfig
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
+1 -5
View File
@@ -66,8 +66,4 @@ config = FasterCacheConfig(
tensor_format="BFCHW",
)
pipeline.transformer.enable_cache(config)
```
## FirstBlockCache
[FirstBlockCache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler to implement generically across a wide range of models, and it has been integrated first for experimental purposes.
```
+1 -3
View File
@@ -41,11 +41,9 @@ class CacheMixin:
Enable caching techniques on the model.
Args:
config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
config (`Union[PyramidAttentionBroadcastConfig]`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
- [`~hooks.FasterCacheConfig`]
- [`~hooks.FirstBlockCacheConfig`]
Example:
@@ -160,7 +160,10 @@ class AutoOffloadStrategy:
if len(hooks) == 0:
return []
current_module_size = model.get_memory_footprint()
try:
current_module_size = model.get_memory_footprint()
except AttributeError:
raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}.")
device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
@@ -703,7 +706,20 @@ class ComponentsManager:
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
# TODO: add a warning if mem_get_info isn't available on `device`.
if device is None:
device = get_device()
if not isinstance(device, torch.device):
device = torch.device(device)
device_type = device.type
device_module = getattr(torch, device_type, torch.cuda)
if not hasattr(device_module, "mem_get_info"):
raise NotImplementedError(
f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
@@ -711,11 +727,7 @@ class ComponentsManager:
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
if device is None:
device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
all_hooks = []
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):