Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 16ebfb7754 | |||
| 53977eedef | |||
| fdf1c11e18 | |||
| 6cf941c69f | |||
| 280a0aca4c | |||
| 9297598dff | |||
| 6a0ae75b55 | |||
| 08b8503ffb | |||
| 56ec287e8a | |||
| 8db89e7453 |
@@ -29,16 +29,11 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
||||
|
||||
# Always use memory-efficient CPU offloading to minimize RAM usage
|
||||
|
||||
_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
||||
@@ -61,6 +56,7 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -72,8 +68,12 @@ class ModuleGroup:
|
||||
self.buffers = buffers
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
@@ -82,125 +82,23 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
# Use the most efficient module-level transfer when possible
|
||||
# This approach mirrors how PyTorch handles full model transfers
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Only onload if some parameters are not on the target device
|
||||
if any(p.device != self.onload_device for p in group_module.parameters()):
|
||||
try:
|
||||
# Try the most efficient approach using _apply
|
||||
if hasattr(group_module, "_apply"):
|
||||
# This is what module.to() uses internally
|
||||
def to_device(t):
|
||||
if t.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
return t.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
return t.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
return t
|
||||
|
||||
# Apply to all tensors without unnecessary copies
|
||||
group_module._apply(to_device)
|
||||
else:
|
||||
# Fallback to direct parameter transfer
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
except Exception as e:
|
||||
# If optimization fails, fall back to direct parameter transfer
|
||||
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle explicit parameters
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle buffers
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
if buffer.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
buffer.data = buffer.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
# For CPU offloading
|
||||
if self.offload_device.type == "cpu":
|
||||
# Synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# Empty GPU cache before offloading to reduce memory fragmentation
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# For module groups, use a single, unified approach that is closest to
|
||||
# the behavior of model.to("cpu")
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Check if we need to offload this module
|
||||
if any(p.device.type != "cpu" for p in group_module.parameters()):
|
||||
# Use PyTorch's built-in to() method directly, which preserves
|
||||
# memory mapping when moving to CPU
|
||||
try:
|
||||
# Non-blocking=False for CPU transfers, as it ensures memory is
|
||||
# immediately available and potentially preserves memory mapping
|
||||
group_module.to("cpu", non_blocking=False)
|
||||
except Exception as e:
|
||||
# If there's any error, fall back to parameter-level offloading
|
||||
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
|
||||
for param in group_module.parameters():
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle explicit parameters - move directly to CPU with non-blocking=False
|
||||
# which can preserve memory mapping in some PyTorch versions
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle buffers
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
if buffer.device.type != "cpu":
|
||||
buffer.data = buffer.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Let Python's normal reference counting handle cleanup
|
||||
# We don't force garbage collection to avoid slowing down inference
|
||||
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
else:
|
||||
# For non-CPU offloading, synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# For non-CPU offloading, use the regular approach
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
@@ -210,9 +108,6 @@ class ModuleGroup:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
# After offloading, we can unpin the memory if configured to do so
|
||||
# We'll keep it pinned by default for better performance
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
@@ -234,7 +129,6 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
# Offload to CPU
|
||||
self.group.offload_()
|
||||
return module
|
||||
|
||||
@@ -419,8 +313,7 @@ def apply_group_offloading(
|
||||
If True, offloading and onloading is done with non-blocking data transfer.
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
|
||||
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
|
||||
overlapping computation and data transfer.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -451,19 +344,12 @@ def apply_group_offloading(
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
# We no longer need a pinned group manager as we're not using pinned memory
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module,
|
||||
num_blocks_per_group,
|
||||
offload_device,
|
||||
onload_device,
|
||||
non_blocking,
|
||||
stream,
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
@@ -498,7 +384,12 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# We no longer need a CPU parameter dictionary
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -520,6 +411,7 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -556,6 +448,7 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -590,7 +483,12 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# We no longer need a CPU parameter dictionary
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -605,6 +503,7 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -649,6 +548,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -667,6 +567,7 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
Reference in New Issue
Block a user