Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5481434][feat] cherry-pick fix to reuse pytorch memory segments occupied by cudagraph (#7747)
Signed-off-by: Hui Gao <huig@nvidia.com>
parent fc4e6d3702
commit a6370fd143
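The idea behind the change, sketched below as a minimal standalone example (the names _graph_pool, _runtime_pool, and get_buffer are illustrative only, not part of the patch): workspace buffers are cached per name, and torch.cuda.is_current_stream_capturing() decides whether a request may be served from buffers that were allocated while capturing a CUDA graph or from the ordinary eager-mode cache.

import math
import torch

# Hypothetical sketch of capture-aware buffer reuse; not the actual patch code.
_graph_pool: dict[str, list[torch.Tensor]] = {}  # buffers allocated during CUDA graph capture
_runtime_pool: dict[str, torch.Tensor] = {}      # buffers allocated in eager (non-capturing) mode

def get_buffer(shape: tuple[int, ...], dtype: torch.dtype, name: str) -> torch.Tensor:
    capturing = torch.cuda.is_current_stream_capturing()
    numel = math.prod(shape)

    # Reuse any cached buffer from the pool matching the current mode that is large enough.
    candidates = _graph_pool.get(name, []) if capturing else (
        [_runtime_pool[name]] if name in _runtime_pool else [])
    for buf in candidates:
        if buf.numel() >= numel:
            return buf[:numel].view(shape)

    # Otherwise allocate a new buffer and remember it in the matching pool.
    buf = torch.zeros(numel, device='cuda', dtype=dtype)
    if capturing:
        _graph_pool.setdefault(name, []).append(buf)
    else:
        _runtime_pool[name] = buf
    return buf.view(shape)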
@@ -1,3 +1,4 @@
+import math
 from typing import Dict, List, Optional, Union
 
 import torch
@@ -365,6 +366,10 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
     3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """
 
+    # To reuse pytorch memory segments allocated during graph capture.
+    allocated_buffer_in_graph_pool: dict[str, list[torch.Tensor]] = {}
+    allocated_buffer_in_runtime: dict[str, torch.Tensor] = {}
+
     def __init__(
         self,
         *,
@@ -410,28 +415,102 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         )
 
     def get_workspace(self, m_max: int, group_size: int):
+
+        def select_buffer_with_more_elements(
+                graph_buffer: Optional[torch.Tensor],
+                runtime_buffer: Optional[torch.Tensor],
+                is_capturing: bool = False
+        ) -> tuple[Optional[torch.Tensor], bool]:
+            if is_capturing and graph_buffer is not None:
+                return graph_buffer, True
+
+            if not is_capturing and runtime_buffer is not None:
+                return runtime_buffer, False
+
+            if graph_buffer is None:
+                return runtime_buffer, False
+
+            if runtime_buffer is None:
+                return graph_buffer, True
+
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                numel_like = math.prod(tensor_shape)
+                runtime_buffer = None
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    buffer = DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name]
+                    numel_buffer = buffer.numel()
+                    runtime_buffer = buffer if numel_buffer >= numel_like else None
+
+                graph_buffer = None
+                # Safely get the list of candidates. Defaults to an empty list if the key is missing.
+                candidate_buffers = DeepGemmFusedMoE.allocated_buffer_in_graph_pool.get(
+                    cache_name, [])
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+                    # The buffer just needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        graph_buffer = buffer
+                        break
+
+                if capture_graph and graph_buffer is not None:
+                    return graph_buffer[0:numel_like].view(tensor_shape)
+                else:
+                    buffer, use_graph = select_buffer_with_more_elements(
+                        graph_buffer,
+                        runtime_buffer,
+                        is_capturing=capture_graph)
+                    if buffer is not None:
+                        if not use_graph and capture_graph:
+                            # Move the buffer into the graph pool since we are in graph capturing mode.
+                            DeepGemmFusedMoE.allocated_buffer_in_runtime.pop(
+                                cache_name, None)
+                            DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                                cache_name, []).append(buffer)
+
+                        return buffer[0:numel_like].view(tensor_shape)
+
+                # Reaching here, no buffer was found, so a new buffer will replace the small one.
+                # Release the old memory first.
+                if cache_name in DeepGemmFusedMoE.allocated_buffer_in_runtime:
+                    del DeepGemmFusedMoE.allocated_buffer_in_runtime[cache_name]
+
+            # No suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if DeepGemmFusedMoE.allocated_buffer_in_graph_pool is not None:
+                if capture_graph:
+                    DeepGemmFusedMoE.allocated_buffer_in_graph_pool.setdefault(
+                        cache_name, []).append(new_buffer)
+                else:
+                    DeepGemmFusedMoE.allocated_buffer_in_runtime[
+                        cache_name] = new_buffer
+            return new_buffer
+
         hidden_size = self.hidden_size
         intermediate_size = self.intermediate_size_per_partition
         num_experts = self.expert_size_per_partition
 
         # create workspace
         fp8_dim = max(hidden_size, intermediate_size)
-        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
-                                  dtype=torch.float8_e4m3fn,
-                                  device='cuda')
-        workspace_1 = torch.empty(
-            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
-            dtype=torch.bfloat16,
-            device='cuda')
+        workspace_0 = get_empty((num_experts * m_max * fp8_dim, ),
+                                dtype=torch.float8_e4m3fn,
+                                cache_name='workspace_0')
+        workspace_1 = get_empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
+            dtype=torch.bfloat16,
+            cache_name='workspace_1')
 
         # create workspace for scaling factors
         m_padded = fp8_utils.align(m_max, 4)
         scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
         scale_k_padded = fp8_utils.align(scale_k, 4)
-        workspace_sf = torch.empty(
-            (num_experts * (scale_k_padded // 4) * m_padded),
-            dtype=torch.int32,
-            device='cuda')
+        workspace_sf = get_empty(
+            (num_experts * (scale_k_padded // 4) * m_padded, ),
+            dtype=torch.int32,
+            cache_name='workspace_sf')
 
         workspace = {
             "workspace_0": workspace_0,
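A note on the design choice (my reading of the patch, not text from the commit): memory allocated while a CUDA graph is being captured belongs to that graph's memory pool, and the same addresses are reused on every replay. Such buffers are therefore tracked separately in allocated_buffer_in_graph_pool, which keeps a list per cache name since graphs for several batch sizes may be captured, and they are never handed back to the eager-mode cache. allocated_buffer_in_runtime keeps a single resizable buffer per name for non-capturing calls; if an eager buffer happens to be picked up during capture, it is migrated into the graph pool so that its addresses remain valid at replay time.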