Overlap gen_fc2_alpha with fc1 using multistream in DenseGEMMFusedMoE
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
parent 5ddbe3ca76
commit 578c0a8e28
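
Before the diff, a note on the technique: the change records a CUDA event on the main stream before the fc1 GEMM is launched, computes gen_fc2_alpha on an auxiliary stream that waits only on that event, and makes the main stream wait on a second event just before the alpha values are consumed, so the small alpha kernel overlaps with the large GEMM. The sketch below restates that record/wait pattern in plain PyTorch; it is illustrative only, and the names OverlapSketch, fc1_gemm, gen_alpha, and combine are hypothetical stand-ins rather than TensorRT-LLM APIs.

    import torch

    class OverlapSketch:
        """Minimal sketch of the multistream overlap pattern (not TensorRT-LLM code)."""

        def __init__(self):
            self.aux_stream = torch.cuda.Stream()   # side stream for the small kernel
            self.main_ready = torch.cuda.Event()    # "inputs are ready" on the main stream
            self.aux_done = torch.cuda.Event()      # "alpha is ready" on the aux stream

        def forward(self, x, fc1_gemm, gen_alpha, combine):
            self.main_ready.record()          # mark the point the aux stream may start from
            fc1_out = fc1_gemm(x)             # heavy GEMM keeps the main stream busy

            with torch.cuda.stream(self.aux_stream):
                self.main_ready.wait()        # aux stream waits only for the inputs
                alpha = gen_alpha(x)          # small kernel runs concurrently with the GEMM
                self.aux_done.record()

            self.aux_done.wait()              # main stream waits right before consuming alpha
            return combine(fc1_out, alpha)    # both results are now correctly ordered

Production code also has to respect the caching allocator's stream semantics for tensors produced on the aux stream and consumed on the main stream (for example via Tensor.record_stream); the diff below does its bookkeeping through aux_stream_dict and event_dict.
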
@@ -9,7 +9,7 @@ from tensorrt_llm.quantization.utils import fp4_utils

 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf
 from .fused_moe_cutlass import CutlassFusedMoE
 from .interface import AlltoallMethodType
 from .quantization import MoEWeightLoadingMode
@@ -137,6 +137,16 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
             init_load_balancer=init_load_balancer,
             without_comm=without_comm,
         )
+        # Initialize auxiliary stream and events for gen_fc2_alpha overlap with fc1
+        if self.aux_stream_dict is None:
+            self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {}
+        if AuxStreamType.MoeFc2Alpha not in self.aux_stream_dict:
+            self.aux_stream_dict[AuxStreamType.MoeFc2Alpha] = torch.cuda.Stream()
+        if self.event_dict is None:
+            self.event_dict = {}
+        for key in [EventType.Main, EventType.MoeFc2Alpha]:
+            if key not in self.event_dict:
+                self.event_dict[key] = torch.cuda.Event()

     def load_weights(self, weights: List[Dict]):
         super().load_weights(weights)
@@ -272,6 +282,7 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
         output_dtype = torch.bfloat16
         num_tokens = x.shape[0]

+        self.event_dict[EventType.Main].record()
         x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size)
         fc1_output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
             x,
@@ -282,17 +293,22 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
             output_dtype,
         )
+
+        with torch.cuda.stream(self.aux_stream_dict[AuxStreamType.MoeFc2Alpha]):
+            self.event_dict[EventType.Main].wait()
+            fc2_alpha = self.gen_fc2_alpha(
+                num_tokens, token_selected_experts, token_final_scales, self.fc2_alpha
+            )
+            alpha_normal_factor = self.fc2_alpha_max
+            alpha_normalized = fc2_alpha / alpha_normal_factor
+            self.event_dict[EventType.MoeFc2Alpha].record()
+
         fc1_output = fc1_output.view(
             [num_tokens, self.expert_size_per_partition, -1]
         ) * self.fc31_alpha.view([1, -1, 1])
         fc1_output = fc1_output.to(output_dtype)

         swiglu_out = swiglu_fused_moe(fc1_output)
-        alpha = self.gen_fc2_alpha(
-            num_tokens, token_selected_experts, token_final_scales, self.fc2_alpha
-        )
-        alpha_normal_factor = self.fc2_alpha_max
-        alpha_normalized = alpha / alpha_normal_factor
-
+        self.event_dict[EventType.MoeFc2Alpha].wait()
         fc2_input = (
             (swiglu_out * alpha_normalized.view([num_tokens, self.expert_size_per_partition, 1]))
             .view([num_tokens, -1])
@@ -22,6 +22,7 @@ aux_stream_name_list = [
     'MoeChunkingOverlap',
     'MoeBalancer',
     'MoeOutputMemset',
+    'MoeFc2Alpha',
 ]
 AuxStreamType = Enum(
     'AuxStreamType',
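
For context on the last hunk: AuxStreamType is built from aux_stream_name_list with Python's functional Enum API, so adding 'MoeFc2Alpha' to the list is all that is needed for AuxStreamType.MoeFc2Alpha to exist. A standalone sketch follows; the list here is abridged and the 'Attention' entry is a made-up placeholder, not taken from the diff.

    from enum import Enum

    aux_stream_name_list = [
        'Attention',          # placeholder entry for illustration only
        'MoeChunkingOverlap',
        'MoeBalancer',
        'MoeOutputMemset',
        'MoeFc2Alpha',        # the name added by this commit
    ]

    # Enum('AuxStreamType', iterable_of_names) creates one member per string,
    # so the new name is immediately addressable as AuxStreamType.MoeFc2Alpha.
    AuxStreamType = Enum('AuxStreamType', aux_stream_name_list)

    assert AuxStreamType.MoeFc2Alpha.name == 'MoeFc2Alpha'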