[None][fix] Pre-allocate workspaces for DeepGEMM MoE to avoid frequent cudaFree/cudaMalloc (#6811)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Fanrong Li 2025-08-13 10:27:57 +08:00 committed by GitHub
parent 47806f09d9
commit 1bbc0e323b
6 changed files with 256 additions and 47 deletions

View File

@ -110,13 +110,11 @@ class CutlassFusedMoE(MoE):
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
if self.use_dp:
max_num_tokens *= model_config.mapping.world_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
if self.moe_max_num_tokens < moe_max_num_tokens:
self.aux_stream = aux_stream_dict[
AuxStreamType.
MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
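
The change above replaces the old world_size-based scaling with max_num_tokens * dp_size, where dp_size (added to Mapping later in this diff) is tp_size when attention DP is enabled and 1 otherwise. A minimal sketch of that sizing logic, using SimpleNamespace stand-ins and hypothetical values in place of the real ModelConfig:

# Illustrative sketch only; config values below are hypothetical.
from types import SimpleNamespace

def resolve_moe_max_num_tokens(model_config) -> int:
    # dp_size == tp_size when attention DP is enabled, otherwise 1
    # (see the Mapping.dp_size property added later in this diff).
    derived = model_config.max_num_tokens * model_config.mapping.dp_size
    # An explicitly configured moe_max_num_tokens still takes precedence.
    return model_config.moe_max_num_tokens or derived

cfg = SimpleNamespace(max_num_tokens=2048,
                      moe_max_num_tokens=None,
                      mapping=SimpleNamespace(dp_size=8))
assert resolve_moe_max_num_tokens(cfg) == 16384

The auxiliary stream below is then only needed when the configured moe_max_num_tokens is smaller than this derived bound, i.e. when chunking can actually occur.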

View File

@ -11,7 +11,7 @@ from tensorrt_llm._utils import nvtx_range
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import AuxStreamType, Fp4QuantizedTensor
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(
def masked_index_copy_group_quant_fp8(
output: torch.Tensor,
output_s: torch.Tensor,
input: torch.Tensor,
start_offsets: torch.Tensor,
row_indices: torch.Tensor,
@ -108,14 +109,10 @@ def masked_index_copy_group_quant_fp8(
col_size = output.shape[1]
dim_size = output.shape[2]
# create padded output_s
alignment = 4
scale_dim = (dim_size + group_size - 1) // group_size
padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
padded_col_size = (col_size + alignment - 1) // alignment * alignment
output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
dtype=torch.int32,
device='cuda')
# get block/grid/stage/warp
num_groups = (dim_size + group_size - 1) // group_size
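
The allocation deleted above (alignment of 4, one scale per group_size elements, four UE8M0 scale bytes packed into each int32) now has to be reproduced by the caller when it carves output_s out of a pre-allocated workspace. A small sketch of that shape computation, with hypothetical sizes:

# Illustrative sketch; sizes are hypothetical and the tensor lives on CPU here.
import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def align(a: int, b: int) -> int:
    return ceil_div(a, b) * b

def packed_scale_shape(row_size: int, col_size: int, dim_size: int,
                       group_size: int = 128, alignment: int = 4):
    scale_dim = ceil_div(dim_size, group_size)      # one scale per group of 128 values
    padded_dim_size = align(scale_dim, alignment)   # pad so 4 scales pack into one int32
    padded_col_size = align(col_size, alignment)
    return (row_size, padded_dim_size // 4, padded_col_size)

shape = packed_scale_shape(row_size=8, col_size=256, dim_size=7168)
output_s = torch.zeros(shape, dtype=torch.int32)
assert shape == (8, 14, 256)   # ceil(7168/128) = 56 scales -> 56 / 4 = 14 int32 per row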
@ -247,6 +244,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
@nvtx_range("[DG]")
def deepgemm_fp8_group_blockwise_gemm(
d: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
sfa: torch.Tensor,
@ -254,10 +252,6 @@ def deepgemm_fp8_group_blockwise_gemm(
masked_m: torch.Tensor,
expected_m: int,
) -> torch.Tensor:
d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
device=b.device,
dtype=torch.bfloat16)
# NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
assert a.stride(-1) == 1
assert b.stride(-1) == 1
@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
masked_m,
expected_m,
disable_ue8m0_cast=True)
return d
return
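
deepgemm_fp8_group_blockwise_gemm now writes into a caller-provided d instead of allocating it, and the NOTES above pin the layout to [G, M, K] @ [G, N, K].mT with only the first masked_m[g] rows of each group valid. A plain-PyTorch reference of that masked grouped GEMM (unquantized, CPU, hypothetical sizes), just to make the semantics concrete; it is not the DeepGEMM kernel:

# Reference sketch of the masked grouped GEMM semantics; not the DeepGEMM kernel.
import torch

def masked_grouped_gemm_ref(d, a, b, masked_m):
    # d: [G, M, N] output, a: [G, M, K], b: [G, N, K]; only the first masked_m[g]
    # rows of each group are valid, matching the per-expert token counts.
    for g in range(a.shape[0]):
        rows = int(masked_m[g])
        d[g, :rows] = a[g, :rows] @ b[g].T
    return d

a = torch.randn(3, 8, 16)
b = torch.randn(3, 4, 16)
masked_m = torch.tensor([8, 5, 0])
d = torch.zeros(3, 8, 4)
masked_grouped_gemm_ref(d, a, b, masked_m)
assert d[1, 5:].abs().sum() == 0 and d[2].abs().sum() == 0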
def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
workspace = workspace[0:g * m * k]
workspace = workspace.as_strided(
size=(g, m, k),
stride=(m * k, k, 1),
)
return workspace
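
set_strides is what makes the pre-allocated buffers reusable: it takes the first g*m*k elements of a flat workspace and returns a (g, m, k) view, so no allocation happens per call. A quick self-contained check of that behavior (small CPU tensor instead of the CUDA workspace):

# Illustrative check of the set_strides helper above.
import torch

def set_strides(workspace: torch.Tensor, g: int, m: int, k: int) -> torch.Tensor:
    # Same body as above: view the first g*m*k elements as a (g, m, k) tensor.
    return workspace[0:g * m * k].as_strided(size=(g, m, k), stride=(m * k, k, 1))

flat = torch.arange(2 * 3 * 4 + 5, dtype=torch.float32)   # slightly larger than needed
view = set_strides(flat, 2, 3, 4)
assert view.shape == (2, 3, 4)
assert view.data_ptr() == flat.data_ptr()   # same storage, no new allocation
view.zero_()                                # writes go straight into the workspace
assert flat[:2 * 3 * 4].abs().sum() == 0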
class DeepGemmFusedMoE(CutlassFusedMoE):
@ -327,6 +330,18 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
apply_router_weight_on_input: bool = False,
layer_idx: Optional[int] = None,
):
if model_config.moe_max_num_tokens is None:
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
# The default moe_max_num_tokens is calculated from the following formula:
# max_isl = 8192, max_batch_size = 1024, mtp = 0
# max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
# moe_max_num_tokens = max_num_tokens * 2 = 18688
# This avoids OOM for the 8k-ISL / 1k-batch case.
default_moe_max_num_tokens = 18688
if moe_max_num_tokens > default_moe_max_num_tokens:
model_config._frozen = False
model_config.moe_max_num_tokens = default_moe_max_num_tokens
model_config._frozen = True
super().__init__(
routing_method=routing_method,
@ -342,6 +357,37 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
layer_idx=layer_idx,
)
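
The cap above keeps the derived moe_max_num_tokens, and therefore the pre-allocated workspaces, bounded by 18688, the value the comment derives for an 8k-ISL/1k-batch configuration. A worked version of that arithmetic under the comment's assumptions:

# Worked arithmetic for the default cap; assumes the comment's 8k-ISL/1k-batch case.
def default_moe_max_num_tokens(max_isl: int = 8192,
                               max_batch_size: int = 1024,
                               mtp: int = 0) -> int:
    # Round (mtp+1)*batch + isl + 128 up to a multiple of 64, then apply the * 2
    # from the comment above.
    max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
    return max_num_tokens * 2

assert default_moe_max_num_tokens() == 18688          # 9344 * 2

# Effect of the clamp for a hypothetical max_num_tokens=4096, dp_size=8 setup:
moe_max_num_tokens = min(4096 * 8, default_moe_max_num_tokens())
assert moe_max_num_tokens == 18688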
def get_workspace(self, m_max: int, group_size: int):
hidden_size = self.hidden_size
intermediate_size = self.intermediate_size
num_experts = self.expert_size_per_partition
# create workspace
fp8_dim = max(hidden_size, intermediate_size)
workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
dtype=torch.float8_e4m3fn,
device='cuda')
workspace_1 = torch.empty(
(num_experts * m_max * max(intermediate_size * 2, hidden_size)),
dtype=torch.bfloat16,
device='cuda')
# create workspace for scaling factors
m_padded = fp8_utils.align(m_max, 4)
scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
scale_k_padded = fp8_utils.align(scale_k, 4)
workspace_sf = torch.empty(
(num_experts * (scale_k_padded // 4) * m_padded),
dtype=torch.int32,
device='cuda')
workspace = {
"workspace_0": workspace_0,
"workspace_1": workspace_1,
"workspace_sf": workspace_sf,
}
return workspace
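
get_workspace allocates three flat buffers: an FP8 buffer sized for the larger of the two GEMM input widths, a BF16 buffer sized for the larger of the two GEMM output widths, and a packed int32 scale buffer. A rough footprint estimate under hypothetical per-rank shapes, just to show how the sizes compose:

# Footprint sketch only; the shapes below are hypothetical, not taken from a real model.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def align(a: int, b: int) -> int:
    return ceil_div(a, b) * b

def workspace_bytes(num_experts: int, m_max: int, hidden_size: int,
                    intermediate_size: int, group_size: int = 128) -> int:
    fp8_dim = max(hidden_size, intermediate_size)
    ws0 = num_experts * m_max * fp8_dim                                        # 1 B per fp8
    ws1 = num_experts * m_max * max(intermediate_size * 2, hidden_size) * 2    # 2 B per bf16
    scale_k_padded = align(ceil_div(fp8_dim, group_size), 4)
    ws_sf = num_experts * (scale_k_padded // 4) * align(m_max, 4) * 4          # 4 B per int32
    return ws0 + ws1 + ws_sf

# e.g. 32 local experts, m_max = 512, hidden = 7168, intermediate = 2048:
print(round(workspace_bytes(32, 512, 7168, 2048) / 2**20), "MiB")   # ~337 MiB

Allocating these buffers up front, instead of inside each kernel wrapper, is what removes the repeated cudaMalloc/cudaFree traffic the commit title refers to.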
def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
@ -362,6 +408,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
workspace: Optional[dict] = None,
) -> torch.Tensor:
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
@ -437,22 +484,38 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
masked_m, token_to_expert_map = preprocess_after_permute(
expert_first_token_offset_tensor, permuted_data_tensor)
m_max = (x.shape[0] + 127) // 128 * 128
expected_m = (token_selected_experts.numel() +
self.expert_size_per_partition -
1) // self.expert_size_per_partition
act_input_fp8 = torch.empty(
(self.expert_size_per_partition, m_max, self.hidden_size),
dtype=torch.float8_e4m3fn,
device='cuda')
# padding and quantization
m_max = fp8_utils.align(x.shape[0], 128)
act_input_fp8 = set_strides(workspace["workspace_0"],
self.expert_size_per_partition, m_max,
self.hidden_size)
m_padded = fp8_utils.align(m_max, 4)
scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
scale_k_padded = fp8_utils.align(scale_k, 4)
act_input_sf = set_strides(workspace["workspace_sf"],
self.expert_size_per_partition,
scale_k_padded // 4, m_padded)
act_input_sf = masked_index_copy_group_quant_fp8(
act_input_fp8,
act_input_sf,
permuted_data_tensor,
expert_first_token_offset_tensor,
token_to_expert_map,
group_size=128)
h1 = deepgemm_fp8_group_blockwise_gemm(
# grouped gemm 1
h1 = set_strides(workspace["workspace_1"],
self.expert_size_per_partition, m_max,
self.intermediate_size * 2)
deepgemm_fp8_group_blockwise_gemm(
d=h1,
a=act_input_fp8,
b=self.w3_w1_weight,
sfa=act_input_sf,
@ -460,9 +523,33 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
masked_m=masked_m,
expected_m=expected_m,
)
act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
h3 = deepgemm_fp8_group_blockwise_gemm(
# activation and quantization
act_input_fp8 = set_strides(workspace["workspace_0"],
self.expert_size_per_partition, m_max,
self.intermediate_size)
scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
scale_k_padded = fp8_utils.align(scale_k, 4)
act_input_sf = set_strides(workspace["workspace_sf"],
self.expert_size_per_partition,
scale_k_padded // 4, m_padded)
act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
output=act_input_fp8,
output_scale=act_input_sf,
input=h1,
quant_group_size=128,
masked_m=masked_m,
scale_ue8m0=True)
# grouped gemm 2
h3 = set_strides(workspace["workspace_1"],
self.expert_size_per_partition, m_max,
self.hidden_size)
deepgemm_fp8_group_blockwise_gemm(
d=h3,
a=act_input_fp8,
b=self.w2_weight,
sfa=act_input_sf,
@ -471,6 +558,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
expected_m=expected_m,
)
# gather and finalize
triton_masked_index_gather(permuted_data_tensor, h3,
expert_first_token_offset_tensor,
token_to_expert_map)
@ -495,3 +583,137 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
)
return final_hidden_states
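
Within forward_chunk the three buffers are recycled across the whole pipeline: workspace_0 backs the FP8 input of GEMM1 and is then re-viewed with a narrower inner dimension for the FP8 input of GEMM2, workspace_1 alternately backs the BF16 outputs h1 and h3, and workspace_sf backs the packed scales of both quantization steps. A schematic of that reuse with hypothetical shapes (int8 stands in for FP8 so the sketch stays CPU-only; no real kernels are called):

# Buffer-reuse sketch only; int8 stands in for float8 and shapes are hypothetical.
import torch

def set_strides(ws: torch.Tensor, g: int, m: int, k: int) -> torch.Tensor:
    return ws[0:g * m * k].as_strided(size=(g, m, k), stride=(m * k, k, 1))

g, m_max, hidden, inter = 4, 128, 256, 64
ws0 = torch.empty(g * m_max * max(hidden, inter), dtype=torch.int8)
ws1 = torch.empty(g * m_max * max(inter * 2, hidden), dtype=torch.bfloat16)

gemm1_in = set_strides(ws0, g, m_max, hidden)     # quantized permuted activations
h1 = set_strides(ws1, g, m_max, inter * 2)        # GEMM1 output
gemm2_in = set_strides(ws0, g, m_max, inter)      # SiLU+mul re-quantized into workspace_0
h3 = set_strides(ws1, g, m_max, hidden)           # GEMM2 output overwrites h1's storage

assert gemm2_in.data_ptr() == gemm1_in.data_ptr()
assert h3.data_ptr() == h1.data_ptr()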
def forward(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
do_finalize: bool = True, # used by other MoE backends
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
if self.use_dp and self.parallel_size > 1:
assert all_rank_num_tokens is not None
assert use_dp_padding is not None
num_rows = sum(all_rank_num_tokens)
else:
num_rows = x.shape[0]
# If num_rows is larger than moe_max_num_tokens * 2, we split the input into multiple chunks,
# because chunked MoE uses two streams and pre-allocates two workspaces.
num_chunks = 1
if num_rows > self.moe_max_num_tokens * 2:
num_chunks = (num_rows + self.moe_max_num_tokens -
1) // self.moe_max_num_tokens
if use_dp_padding:
all_rank_num_tokens_padded = [all_rank_max_num_tokens
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens
if num_chunks == 1:
# create workspace
num_rows = x.shape[0]
if self.use_dp:
num_rows = sum(all_rank_num_tokens_padded)
m_max = fp8_utils.align(num_rows, 128)
workspace = self.get_workspace(m_max, 128)
outputs = self.forward_chunk(
x,
router_logits,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding,
workspace=workspace)
outputs = self.reducescatter_or_allreduce(
outputs,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding)
else:
if self.use_dp:
all_rank_chunk_size_list = [
self.split_chunk(val, num_chunks)
for val in all_rank_num_tokens_padded
]
all_rank_num_tokens_list = [[
val[idx_chunk] for val in all_rank_chunk_size_list
] for idx_chunk in range(num_chunks)]
chunk_size_list = all_rank_chunk_size_list[self.rank]
else:
all_rank_num_tokens_list = [None] * num_chunks
chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
# create workspace
chunk_size_0 = sum(all_rank_num_tokens_list[0]
) if self.use_dp else chunk_size_list[0]
chunk_size_1 = sum(all_rank_num_tokens_list[1]
) if self.use_dp else chunk_size_list[1]
workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
128)
workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
128)
x_list = x.split(chunk_size_list)
router_logits_list = router_logits.split(chunk_size_list)
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
def _forward_chunk(x_, router_logits_, idx, workspace):
return self.forward_chunk(
x_,
router_logits_,
all_rank_num_tokens=all_rank_num_tokens_list[idx]
if self.use_dp else None,
use_dp_padding=use_dp_padding,
workspace=workspace)
def _reducescatter_or_allreduce(x_, idx):
return self.reducescatter_or_allreduce(
x_,
all_rank_num_tokens=all_rank_num_tokens_list[idx],
use_dp_padding=use_dp_padding)
outputs_list = []
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
outputs = _forward_chunk(x, router_logits, idx_chunk,
workspace_0)
if idx_chunk > 0:
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], idx_chunk - 1)
else:
outputs = _forward_chunk(x, router_logits, idx_chunk,
workspace_1)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], idx_chunk - 1)
outputs_list.append(outputs)
if num_chunks % 2 == 0:
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], -1)
else:
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], -1)
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeChunkingOverlap].record()
self.event_dict[EventType.MoeChunkingOverlap].wait()
outputs = torch.cat(outputs_list)
if self.use_dp and self.parallel_size > 1:
rank = self.mapping.tp_rank
outputs = outputs[:all_rank_num_tokens[rank]]
return outputs
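
With more than one chunk, forward pre-allocates two workspaces and ping-pongs chunks between the main stream and aux_stream, postponing each chunk's reduce-scatter/all-reduce by one iteration so it overlaps with the next chunk's compute. A stripped-down sketch of that scheduling pattern with synchronous Python stand-ins for the kernels and collectives (no streams or events here):

# Scheduling sketch only; forward_chunk/reduce_chunk are placeholders, not real kernels.
def chunked_moe(chunks, forward_chunk, reduce_chunk):
    outputs = []
    for idx, chunk in enumerate(chunks):
        # Alternate between the two pre-allocated workspaces (and, in the real
        # code, between aux_stream for even chunks and the main stream for odd ones).
        out = forward_chunk(chunk, workspace_id=idx % 2)
        if idx > 0:
            # Reduce the *previous* chunk so the collective overlaps this chunk's GEMMs.
            outputs[-1] = reduce_chunk(outputs[-1])
        outputs.append(out)
    outputs[-1] = reduce_chunk(outputs[-1])   # flush the last chunk's collective
    return outputs

fwd = lambda x, workspace_id: ("moe", x, workspace_id)
red = lambda y: ("reduced", y)
print(chunked_moe(["chunk0", "chunk1", "chunk2"], fwd, red))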

View File

@ -81,13 +81,9 @@ class VanillaMoE(nn.ModuleList):
self.num_experts)
self.expert_size_per_partition = self.expert_end - self.expert_start
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
if self.use_dp:
max_num_tokens *= model_config.mapping.world_size
self.moe_max_num_tokens = (model_config.moe_max_num_tokens
if model_config.moe_max_num_tokens
is not None else max_num_tokens)
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
self._weights_created = False
if not model_config.skip_create_weights_in_init:

View File

@ -150,12 +150,11 @@ class WideEPMoE(MoE):
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
max_num_tokens *= model_config.mapping.world_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
if self.moe_max_num_tokens < moe_max_num_tokens:
self.aux_stream = aux_stream_dict[
AuxStreamType.
MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(

View File

@ -372,6 +372,10 @@ class Mapping(object):
def local_rank(self):
return self.rank % self.gpus_per_node
@property
def dp_size(self):
return self.tp_size if self.enable_attention_dp else 1
def has_cp(self):
return self.cp_size > 1
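
A minimal stand-in reproducing just the new property, to show the two values it can take:

# Minimal stand-in for the real Mapping, reproducing only the new dp_size property.
class _Mapping:
    def __init__(self, tp_size: int, enable_attention_dp: bool):
        self.tp_size = tp_size
        self.enable_attention_dp = enable_attention_dp

    @property
    def dp_size(self) -> int:
        # With attention DP each TP rank feeds its own batch into MoE, so the
        # MoE token budget scales by tp_size; otherwise it stays at 1x.
        return self.tp_size if self.enable_attention_dp else 1

assert _Mapping(8, True).dp_size == 8
assert _Mapping(8, False).dp_size == 1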

View File

@ -302,6 +302,8 @@ def _silu_and_mul_post_quant_kernel(
def silu_and_mul_masked_post_quant_fwd(
output: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
@ -328,18 +330,6 @@ def silu_and_mul_masked_post_quant_fwd(
g, m, k = input.shape
k = k // 2
# Create output
output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda")
# Create output scale
alignment = 4
scale_k = ceil_div(k, quant_group_size)
m_padded = align(m, alignment)
scale_k_padded = align(scale_k, alignment)
output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
dtype=torch.int32,
device='cuda')
# Get block/grid/stage/warp
expert_num = len(masked_m)
@ -382,7 +372,7 @@ def silu_and_mul_masked_post_quant_fwd(
g,
tma_stride_check=True,
)
return output, output_scale
return output_scale
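
silu_and_mul_masked_post_quant_fwd follows the same caller-allocates pattern after this change: the FP8 output and the packed scale tensor are passed in (the MoE path above carves them out of its workspaces) and only the scale tensor is returned. A hedged call sketch built from the shapes the removed allocation used; sizes are hypothetical, CUDA and an FP8-capable build are assumed, and the import path is my assumption rather than something shown in this diff:

# Call sketch; sizes are hypothetical, requires CUDA, and the import path is assumed.
import torch
# from tensorrt_llm.quantization.utils import fp8_utils   # assumed module path

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def align(a: int, b: int) -> int:
    return ceil_div(a, b) * b

g, m, k2, group = 8, 256, 4096, 128            # input is (g, m, 2*k)
k = k2 // 2
h1 = torch.randn(g, m, k2, dtype=torch.bfloat16, device="cuda")
masked_m = torch.full((g,), m, dtype=torch.int32, device="cuda")

# Allocations that used to live inside the function, now done by the caller:
output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda")
output_scale = torch.zeros((g, align(ceil_div(k, group), 4) // 4, align(m, 4)),
                           dtype=torch.int32, device="cuda")

# output_scale = fp8_utils.silu_and_mul_masked_post_quant_fwd(
#     output=output, output_scale=output_scale, input=h1,
#     quant_group_size=group, masked_m=masked_m, scale_ue8m0=True)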
@triton.jit