[None][fix] Pre-allocate workspaces for DeepGEMM MoE to avoid frequent cudaFree/cudaMalloc (#6811)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
parent 47806f09d9
commit 1bbc0e323b
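The buffers between the grouped GEMMs were previously allocated with torch.empty/torch.zeros inside every forward call, which translates into repeated cudaMalloc/cudaFree. This change allocates flat workspaces once per forward pass and re-views them for each intermediate. A minimal sketch of the pattern, with hypothetical sizes that are not taken from the diff below:

import torch

# Minimal sketch of the pre-allocate-and-view pattern; sizes are hypothetical.
# The flat buffer is allocated once; later steps only create strided views into
# it, so no additional cudaMalloc/cudaFree happens on the hot path.
flat = torch.empty(4 * 128 * 256, dtype=torch.float8_e4m3fn, device='cuda')

def view_as(workspace: torch.Tensor, g: int, m: int, k: int) -> torch.Tensor:
    # same idea as the set_strides() helper introduced further down
    return workspace[:g * m * k].as_strided(size=(g, m, k), stride=(m * k, k, 1))

buf_gemm1 = view_as(flat, 4, 128, 256)  # view for one intermediate
buf_gemm2 = view_as(flat, 4, 128, 64)   # same memory, different logical shape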
@@ -110,13 +110,11 @@ class CutlassFusedMoE(MoE):
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition

-        max_num_tokens = model_config.max_num_tokens
-        # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
@@ -11,7 +11,7 @@ from tensorrt_llm._utils import nvtx_range

 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(

 def masked_index_copy_group_quant_fp8(
     output: torch.Tensor,
+    output_s: torch.Tensor,
     input: torch.Tensor,
     start_offsets: torch.Tensor,
     row_indices: torch.Tensor,
@@ -108,14 +109,10 @@ def masked_index_copy_group_quant_fp8(
     col_size = output.shape[1]
     dim_size = output.shape[2]

-    # create padded output_s
     alignment = 4
     scale_dim = (dim_size + group_size - 1) // group_size
     padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
     padded_col_size = (col_size + alignment - 1) // alignment * alignment
-    output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
-                           dtype=torch.int32,
-                           device='cuda')

     # get block/grid/stage/warp
     num_groups = (dim_size + group_size - 1) // group_size
@@ -247,6 +244,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,

 @nvtx_range("[DG]")
 def deepgemm_fp8_group_blockwise_gemm(
+    d: torch.Tensor,
     a: torch.Tensor,
     b: torch.Tensor,
     sfa: torch.Tensor,
@@ -254,10 +252,6 @@ def deepgemm_fp8_group_blockwise_gemm(
     masked_m: torch.Tensor,
     expected_m: int,
 ) -> torch.Tensor:
-    d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
-                    device=b.device,
-                    dtype=torch.bfloat16)
-
     # NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
     assert a.stride(-1) == 1
     assert b.stride(-1) == 1
@@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
         masked_m,
         expected_m,
         disable_ue8m0_cast=True)
-    return d
+    return


+def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
+    workspace = workspace[0:g * m * k]
+    workspace = workspace.as_strided(
+        size=(g, m, k),
+        stride=(m * k, k, 1),
+    )
+    return workspace
+
+
 class DeepGemmFusedMoE(CutlassFusedMoE):
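set_strides only reinterprets a prefix of a flat workspace as a 3-D tensor; it performs no copy and no allocation. A usage sketch with illustrative sizes:

import torch

# Illustrative sizes only; set_strides is the helper added in this diff.
num_experts, m_max, hidden = 8, 256, 1024
flat_ws = torch.empty(num_experts * m_max * hidden,
                      dtype=torch.float8_e4m3fn,
                      device='cuda')
view = set_strides(flat_ws, num_experts, m_max, hidden)
assert view.shape == (num_experts, m_max, hidden)
assert view.data_ptr() == flat_ws.data_ptr()  # same memory, no new allocation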
@@ -327,6 +330,18 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
     ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8196, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # It can avoid OOM for 8k/1k cases.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config._frozen = False
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens
+                model_config._frozen = True

         super().__init__(
             routing_method=routing_method,
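When moe_max_num_tokens is not set by the user, the derived value max_num_tokens * dp_size is capped at 18688 so that the pre-allocated workspaces stay bounded; an explicit user setting is left untouched. A small sketch of the resulting behavior (values are illustrative):

# Illustrative only: the effective value seen by DeepGemmFusedMoE when the
# user leaves moe_max_num_tokens unset.
def effective_moe_max_num_tokens(max_num_tokens: int, dp_size: int,
                                 user_value=None,
                                 default_cap: int = 18688) -> int:
    if user_value is not None:
        return user_value                 # explicit configuration wins
    derived = max_num_tokens * dp_size
    return min(derived, default_cap)      # cap keeps workspace memory bounded

print(effective_moe_max_num_tokens(9344, 1))  # 9344  (below the cap)
print(effective_moe_max_num_tokens(9344, 4))  # 18688 (clamped)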
@@ -342,6 +357,37 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
             layer_idx=layer_idx,
         )

+    def get_workspace(self, m_max: int, group_size: int):
+        hidden_size = self.hidden_size
+        intermediate_size = self.intermediate_size
+        num_experts = self.expert_size_per_partition
+
+        # create workspace
+        fp8_dim = max(hidden_size, intermediate_size)
+        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
+                                  dtype=torch.float8_e4m3fn,
+                                  device='cuda')
+        workspace_1 = torch.empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+            dtype=torch.bfloat16,
+            device='cuda')
+
+        # create workspace for scaling factors
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        workspace_sf = torch.empty(
+            (num_experts * (scale_k_padded // 4) * m_padded),
+            dtype=torch.int32,
+            device='cuda')
+
+        workspace = {
+            "workspace_0": workspace_0,
+            "workspace_1": workspace_1,
+            "workspace_sf": workspace_sf,
+        }
+        return workspace
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
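get_workspace sizes the three buffers for the worst case of both grouped GEMMs: workspace_0 holds FP8 activations up to max(hidden_size, intermediate_size) wide, workspace_1 holds BF16 GEMM outputs up to max(2 * intermediate_size, hidden_size) wide, and workspace_sf holds the packed int32 scale factors. A hedged sketch of the element counts with hypothetical dimensions:

# Hypothetical dimensions, chosen only to make the arithmetic concrete.
num_experts, m_max, group_size = 8, 1024, 128
hidden_size, intermediate_size = 7168, 2048

fp8_dim = max(hidden_size, intermediate_size)                        # 7168
ws0 = num_experts * m_max * fp8_dim                                  # FP8 elements (1 B each)
ws1 = num_experts * m_max * max(intermediate_size * 2, hidden_size)  # BF16 elements (2 B each)

scale_k = (fp8_dim + group_size - 1) // group_size                   # ceil_div
scale_k_padded = (scale_k + 3) // 4 * 4                              # align to 4
m_padded = (m_max + 3) // 4 * 4
ws_sf = num_experts * (scale_k_padded // 4) * m_padded               # int32 elements (4 B each)

print(ws0, ws1 * 2, ws_sf * 4)  # approximate byte footprint of each workspace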
@@ -362,6 +408,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        workspace: Optional[dict] = None,
     ) -> torch.Tensor:
         if isinstance(x, Fp4QuantizedTensor):
             assert output_dtype is not None
@@ -437,22 +484,38 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         masked_m, token_to_expert_map = preprocess_after_permute(
             expert_first_token_offset_tensor, permuted_data_tensor)

-        m_max = (x.shape[0] + 127) // 128 * 128
         expected_m = (token_selected_experts.numel() +
                       self.expert_size_per_partition -
                       1) // self.expert_size_per_partition
-        act_input_fp8 = torch.empty(
-            (self.expert_size_per_partition, m_max, self.hidden_size),
-            dtype=torch.float8_e4m3fn,
-            device='cuda')
+
+        # padding and quantization
+        m_max = fp8_utils.align(x.shape[0], 128)
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.hidden_size)
+
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
         act_input_sf = masked_index_copy_group_quant_fp8(
             act_input_fp8,
+            act_input_sf,
             permuted_data_tensor,
             expert_first_token_offset_tensor,
             token_to_expert_map,
             group_size=128)

-        h1 = deepgemm_fp8_group_blockwise_gemm(
+        # grouped gemm 1
+        h1 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.intermediate_size * 2)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h1,
             a=act_input_fp8,
             b=self.w3_w1_weight,
             sfa=act_input_sf,
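fp8_utils.align(x.shape[0], 128) replaces the earlier inline expression (x.shape[0] + 127) // 128 * 128. A short sketch of the assumed round-up semantics of the two helpers used throughout this hunk:

# Assumed semantics, matching the inline expressions they replace.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def align(a: int, b: int) -> int:
    return ceil_div(a, b) * b

assert align(1000, 128) == 1024   # identical to (1000 + 127) // 128 * 128
assert ceil_div(7168, 128) == 56  # e.g. scale_k for a 7168-wide hidden dim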
@@ -460,9 +523,33 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
             masked_m=masked_m,
             expected_m=expected_m,
         )
-        act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
-            input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
-        h3 = deepgemm_fp8_group_blockwise_gemm(
+
+        # activation and quantization
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.intermediate_size)
+
+        scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
+        act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
+            output=act_input_fp8,
+            output_scale=act_input_sf,
+            input=h1,
+            quant_group_size=128,
+            masked_m=masked_m,
+            scale_ue8m0=True)
+
+        # grouped gemm 2
+        h3 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.hidden_size)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h3,
             a=act_input_fp8,
             b=self.w2_weight,
             sfa=act_input_sf,
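Both GEMM phases re-view the same two workspaces, which is safe because get_workspace sizes them for the larger of the two shapes. A brief check under the diff's own sizing, with hypothetical dimensions:

# Hypothetical dimensions; the assertions mirror how get_workspace sizes the buffers.
hidden_size, intermediate_size = 7168, 2048
fp8_dim = max(hidden_size, intermediate_size)       # row width of workspace_0
bf16_dim = max(intermediate_size * 2, hidden_size)  # row width of workspace_1

# grouped gemm 1: FP8 input (E, m_max, hidden_size), BF16 output (E, m_max, 2 * intermediate_size)
assert hidden_size <= fp8_dim and intermediate_size * 2 <= bf16_dim
# grouped gemm 2: FP8 input (E, m_max, intermediate_size), BF16 output (E, m_max, hidden_size)
assert intermediate_size <= fp8_dim and hidden_size <= bf16_dim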
@@ -471,6 +558,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
             expected_m=expected_m,
         )

+        # gather and finalize
         triton_masked_index_gather(permuted_data_tensor, h3,
                                    expert_first_token_offset_tensor,
                                    token_to_expert_map)
@@ -495,3 +583,137 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         )

         return final_hidden_states
+
+    def forward(
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        do_finalize: bool = True,  # used by other MoE backends
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        use_dp_padding: Optional[bool] = None,
+    ) -> torch.Tensor:
+        assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
+        if self.use_dp and self.parallel_size > 1:
+            assert all_rank_num_tokens is not None
+            assert use_dp_padding is not None
+            num_rows = sum(all_rank_num_tokens)
+        else:
+            num_rows = x.shape[0]
+
+        # In case of num_rows is larger than max_chunk_size * 2, we need to split the input into multiple chunks.
+        # Because we will use two streams in chunked moe and preallocate two workspaces.
+        num_chunks = 1
+        if num_rows > self.moe_max_num_tokens * 2:
+            num_chunks = (num_rows + self.moe_max_num_tokens -
+                          1) // self.moe_max_num_tokens
+
+        if use_dp_padding:
+            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+                                          ] * len(all_rank_num_tokens)
+        else:
+            all_rank_num_tokens_padded = all_rank_num_tokens
+
+        if num_chunks == 1:
+            # create workspace
+            num_rows = x.shape[0]
+            if self.use_dp:
+                num_rows = sum(all_rank_num_tokens_padded)
+            m_max = fp8_utils.align(num_rows, 128)
+            workspace = self.get_workspace(m_max, 128)
+            outputs = self.forward_chunk(
+                x,
+                router_logits,
+                output_dtype,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding,
+                workspace=workspace)
+            outputs = self.reducescatter_or_allreduce(
+                outputs,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding)
+        else:
+            if self.use_dp:
+                all_rank_chunk_size_list = [
+                    self.split_chunk(val, num_chunks)
+                    for val in all_rank_num_tokens_padded
+                ]
+                all_rank_num_tokens_list = [[
+                    val[idx_chunk] for val in all_rank_chunk_size_list
+                ] for idx_chunk in range(num_chunks)]
+                chunk_size_list = all_rank_chunk_size_list[self.rank]
+            else:
+                all_rank_num_tokens_list = [None] * num_chunks
+                chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
+
+            # create workspace
+            chunk_size_0 = sum(all_rank_num_tokens_list[0]
+                               ) if self.use_dp else chunk_size_list[0]
+            chunk_size_1 = sum(all_rank_num_tokens_list[1]
+                               ) if self.use_dp else chunk_size_list[1]
+            workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
+                                             128)
+            workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
+                                             128)
+
+            x_list = x.split(chunk_size_list)
+            router_logits_list = router_logits.split(chunk_size_list)
+
+            self.event_dict[EventType.Main].record()
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.Main].wait()
+
+            def _forward_chunk(x_, router_logits_, idx, workspace):
+                return self.forward_chunk(
+                    x_,
+                    router_logits_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx]
+                    if self.use_dp else None,
+                    use_dp_padding=use_dp_padding,
+                    workspace=workspace)
+
+            def _reducescatter_or_allreduce(x_, idx):
+                return self.reducescatter_or_allreduce(
+                    x_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx],
+                    use_dp_padding=use_dp_padding)
+
+            outputs_list = []
+            # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
+            for idx_chunk, (x, router_logits) in enumerate(
+                    zip(x_list, router_logits_list)):
+
+                if idx_chunk % 2 == 0:
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                                 workspace_0)
+                    if idx_chunk > 0:
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+                else:
+                    outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                             workspace_1)
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+
+                outputs_list.append(outputs)
+
+            if num_chunks % 2 == 0:
+                outputs_list[-1] = _reducescatter_or_allreduce(
+                    outputs_list[-1], -1)
+            else:
+                with torch.cuda.stream(self.aux_stream):
+                    outputs_list[-1] = _reducescatter_or_allreduce(
+                        outputs_list[-1], -1)
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.MoeChunkingOverlap].record()
+            self.event_dict[EventType.MoeChunkingOverlap].wait()
+
+            outputs = torch.cat(outputs_list)
+
+        if self.use_dp and self.parallel_size > 1:
+            rank = self.mapping.tp_rank
+            outputs = outputs[:all_rank_num_tokens[rank]]
+        return outputs
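The chunked path alternates the two pre-allocated workspaces between even and odd chunks and defers each chunk's reduce-scatter/all-reduce by one iteration so it overlaps with the next chunk's compute on the other stream. A pure-Python trace of that schedule (no CUDA, simplified):

# Simplified trace of the double-buffered schedule used in the chunked branch.
def trace_schedule(num_chunks: int) -> None:
    for idx in range(num_chunks):
        ws = "workspace_0" if idx % 2 == 0 else "workspace_1"
        stream = "aux_stream" if idx % 2 == 0 else "main stream"
        print(f"chunk {idx}: forward_chunk on {stream} using {ws}")
        if idx > 0:
            # communication of the previous chunk is issued on the other stream,
            # overlapping with this chunk's compute
            print(f"chunk {idx - 1}: reducescatter_or_allreduce")
    print(f"chunk {num_chunks - 1}: reducescatter_or_allreduce")

trace_schedule(3)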
@@ -81,13 +81,9 @@ class VanillaMoE(nn.ModuleList):
                                           self.num_experts)
         self.expert_size_per_partition = self.expert_end - self.expert_start

-        max_num_tokens = model_config.max_num_tokens
-        # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
-                                   if model_config.moe_max_num_tokens
-                                   is not None else max_num_tokens)
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
@@ -150,12 +150,11 @@ class WideEPMoE(MoE):
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition

-        max_num_tokens = model_config.max_num_tokens
-        # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
@@ -372,6 +372,10 @@ class Mapping(object):
     def local_rank(self):
         return self.rank % self.gpus_per_node

+    @property
+    def dp_size(self):
+        return self.tp_size if self.enable_attention_dp else 1
+
     def has_cp(self):
         return self.cp_size > 1
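dp_size evaluates to tp_size only when attention DP is enabled and to 1 otherwise, so max_num_tokens * mapping.dp_size collapses to plain max_num_tokens for non-attention-DP configurations. A small sketch with a stand-in for Mapping (illustrative values):

# _FakeMapping is a stand-in used only for this sketch, not the real Mapping class.
class _FakeMapping:
    def __init__(self, tp_size: int, enable_attention_dp: bool):
        self.tp_size = tp_size
        self.enable_attention_dp = enable_attention_dp

    @property
    def dp_size(self) -> int:
        return self.tp_size if self.enable_attention_dp else 1

max_num_tokens = 4096
print(max_num_tokens * _FakeMapping(8, True).dp_size)   # 32768: scales with DP ranks
print(max_num_tokens * _FakeMapping(8, False).dp_size)  # 4096: unchanged without attention DP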
@@ -302,6 +302,8 @@ def _silu_and_mul_post_quant_kernel(


 def silu_and_mul_masked_post_quant_fwd(
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
     input: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
@@ -328,18 +330,6 @@ def silu_and_mul_masked_post_quant_fwd(
     g, m, k = input.shape
     k = k // 2

-    # Create output
-    output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda")
-
-    # Create output scale
-    alignment = 4
-    scale_k = ceil_div(k, quant_group_size)
-    m_padded = align(m, alignment)
-    scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
-                               dtype=torch.int32,
-                               device='cuda')

     # Get block/grid/stage/warp
     expert_num = len(masked_m)
@@ -382,7 +372,7 @@ def silu_and_mul_masked_post_quant_fwd(
         g,
         tma_stride_check=True,
     )
-    return output, output_scale
+    return output_scale


 @triton.jit
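After this change the caller owns both result buffers: it pre-allocates the FP8 output and the packed int32 output_scale (typically as set_strides views over the MoE workspaces) and the function returns only output_scale. A hedged call sketch with hypothetical shapes; the fp8_utils prefix refers to the module patched above:

import torch

# Hypothetical shapes; in the MoE path these buffers are set_strides() views
# over the pre-allocated workspaces instead of fresh allocations.
g, m, k2, group = 8, 256, 4096, 128         # input is (g, m, 2k) with k = k2 // 2
k = k2 // 2
h = torch.randn(g, m, k2, dtype=torch.bfloat16, device='cuda')
masked_m = torch.full((g,), m, dtype=torch.int32, device='cuda')

out = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device='cuda')
scale_k = (k + group - 1) // group
out_scale = torch.zeros((g, (scale_k + 3) // 4, (m + 3) // 4 * 4),
                        dtype=torch.int32, device='cuda')

out_scale = fp8_utils.silu_and_mul_masked_post_quant_fwd(
    output=out,
    output_scale=out_scale,
    input=h,
    quant_group_size=group,
    masked_m=masked_m,
    scale_ue8m0=True)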