[https://nvbugs/5481434][feat] cherry-pick fix to reuse pytorch memory segments occupied by cudagraph (#7747)

Signed-off-by: Hui Gao <huig@nvidia.com>
HuiGao-NV 2025-09-19 10:25:21 +08:00 committed by GitHub
parent fc4e6d3702
commit a6370fd143

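Context for the change: tensors allocated while a CUDA graph is being captured come from the graph's private memory pool, so after capture those segments stay reserved by the graph and are not returned to PyTorch's regular caching allocator. This patch keeps class-level caches of such workspace tensors and hands out views of them instead of allocating fresh buffers on every call. Below is a minimal standalone sketch of that reuse pattern; the helper name cached_empty and the module-level dicts are illustrative, not the patch's actual code.

import math
from typing import Dict, List

import torch

# Hypothetical name-keyed caches: one for buffers created during graph capture,
# one for buffers created during normal (eager) execution.
_graph_buffers: Dict[str, List[torch.Tensor]] = {}
_eager_buffers: Dict[str, torch.Tensor] = {}


def cached_empty(shape: tuple, dtype: torch.dtype, name: str) -> torch.Tensor:
    """Return a tensor of `shape`, reusing a previously allocated buffer when possible."""
    capturing = torch.cuda.is_current_stream_capturing()
    needed = math.prod(shape)

    # Prefer a buffer that already lives in the CUDA graph pool; it would otherwise
    # stay reserved by the graph but unused.
    for buf in _graph_buffers.get(name, []):
        if buf.numel() >= needed:
            return buf[:needed].view(shape)

    # Outside capture, an eager buffer of sufficient size can be reused as well.
    buf = _eager_buffers.get(name)
    if not capturing and buf is not None and buf.numel() >= needed:
        return buf[:needed].view(shape)

    # No reusable buffer: allocate a new one and record it in the matching pool.
    new_buf = torch.zeros(needed, device='cuda', dtype=dtype)
    if capturing:
        _graph_buffers.setdefault(name, []).append(new_buf)
    else:
        _eager_buffers[name] = new_buf
    return new_buf.view(shape)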

@@ -1,3 +1,4 @@
+import math
 from typing import Dict, List, Optional, Union
 import torch
@@ -365,6 +366,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
     3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """
 
+    # To reuse pytorch memory segments allocated during graph capture.
+    allocated_buffer_in_graph_pool: dict[str, list[torch.Tensor]] = {}
+    allocated_buffer_in_runtime: dict[str, torch.Tensor] = {}
+
     def __init__(
         self,
         *,
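The two dicts added above are class attributes, so every DeepGemmFusedMoE instance and every CUDA graph capture share the same caches; a workspace recorded during one capture can be reused by the next one. A hedged usage sketch of how the capture flag decides which pool a buffer lands in, reusing the cached_empty helper sketched earlier (graph setup shortened; real code typically warms the workload up and synchronizes before capture):

import torch

g = torch.cuda.CUDAGraph()

# Eager call: is_current_stream_capturing() is False, so the buffer goes to the eager pool.
out = cached_empty((8, 1024), torch.bfloat16, 'workspace_demo')

# Inside torch.cuda.graph(), is_current_stream_capturing() is True, so the same
# logical buffer is served from (and recorded into) the graph pool instead.
with torch.cuda.graph(g):
    out = cached_empty((8, 1024), torch.bfloat16, 'workspace_demo')

g.replay()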
@@ -410,28 +415,102 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         )
 
     def get_workspace(self, m_max: int, group_size: int):
+
+        def select_buffer_with_more_elements(
+                graph_buffer: Optional[torch.Tensor],
+                runtime_buffer: Optional[torch.Tensor],
+                is_capturing: bool = False
+        ) -> tuple[Optional[torch.Tensor], bool]:
+            if is_capturing and graph_buffer is not None:
+                return graph_buffer, True
+            if is_capturing == False and runtime_buffer is not None:
+                return runtime_buffer, False
+            if graph_buffer is None:
+                return runtime_buffer, False
+            if runtime_buffer is None:
+                return graph_buffer, True
+
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                numel_like = math.prod(tensor_shape)
+
+                runtime_buffer = None
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    buffer = DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name]
+                    numel_buffer = buffer.numel()
+                    runtime_buffer = buffer if numel_buffer >= numel_like else None
+
+                graph_buffer = None
+                # Safely get the list of candidates. Defaults to an empty list if key is missing.
+                candidate_buffers = DeepGemmFusedMoE.allocated_buffer_in_graph_pool.get(
+                    cache_name, [])
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+                    # buffer just needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        graph_buffer = buffer
+                        break
+
+                if capture_graph and graph_buffer is not None:
+                    return graph_buffer[0:numel_like].view(tensor_shape)
+                else:
+                    buffer, use_graph = select_buffer_with_more_elements(
+                        graph_buffer,
+                        runtime_buffer,
+                        is_capturing=capture_graph)
+                    if buffer is not None:
+                        if not use_graph and capture_graph:
+                            # move the buffer into graph buffers since it's running in graph capturing mode.
+                            DeepGemmFusedMoE.allocated_buffer_in_runtime.pop(
+                                cache_name, None)
+                            DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                                cache_name, []).append(buffer)
+                        return buffer[0:numel_like].view(tensor_shape)
+
+                # Reach here, no buffer is found. Then, we will use a new buffer to replace the small one. Release the memory first.
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    del DeepGemmFusedMoE.allocated_buffer_in_runtime[cache_name]
+
+            # If we get here, no suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                if capture_graph:
+                    DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                        cache_name, []).append(new_buffer)
+                else:
+                    DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name] = new_buffer
+            return new_buffer
+
         hidden_size = self.hidden_size
         intermediate_size = self.intermediate_size_per_partition
         num_experts = self.expert_size_per_partition
 
         # create workspace
         fp8_dim = max(hidden_size, intermediate_size)
-        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
-                                  dtype=torch.float8_e4m3fn,
-                                  device='cuda')
-        workspace_1 = torch.empty(
-            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+        workspace_0 = get_empty((num_experts * m_max * fp8_dim, ),
+                                dtype=torch.float8_e4m3fn,
+                                cache_name='workspace_0')
+        workspace_1 = get_empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
             dtype=torch.bfloat16,
-            device='cuda')
+            cache_name='workspace_1')
 
         # create workspace for scaling factors
         m_padded = fp8_utils.align(m_max, 4)
         scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
         scale_k_padded = fp8_utils.align(scale_k, 4)
-        workspace_sf = torch.empty(
-            (num_experts * (scale_k_padded // 4) * m_padded),
+        workspace_sf = get_empty(
+            (num_experts * (scale_k_padded // 4) * m_padded, ),
             dtype=torch.int32,
-            device='cuda')
+            cache_name='workspace_sf')
         workspace = {
             "workspace_0": workspace_0,