chore: some refactor on WideEP (#5727)

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
dongxuy04 2025-07-09 14:26:57 +08:00 committed by GitHub
parent 64fd64fcf2
commit dd3c736c7e


@@ -11,8 +11,7 @@ from tensorrt_llm.mapping import Mapping
from ...distributed import allgather, reducescatter
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig
from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
reswizzle_sf, swizzle_sf, unswizzle_sf)
from ...utils import EventType, Fp4QuantizedTensor, swizzle_sf
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_load_balancer import get_moe_load_balancer
@@ -83,6 +82,12 @@ class WideEPMoE(MoE):
weight_loading_mode=weight_loading_mode,
)
assert self.use_dp, "Attention DP should be used with WideEP."
assert self.parallel_size > 1, "WideEP should only be enabled with parallel_size > 1"
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input
assert self.apply_router_weight_on_input is False, "WideEP doesn't support apply_router_weight_on_input."
self.layer_idx = layer_idx
moe_load_balancer = get_moe_load_balancer()
@@ -128,9 +133,6 @@ class WideEPMoE(MoE):
self.expert_size_per_partition = num_experts // self.ep_size
self.num_slots = num_experts
if self.smart_router:
assert self.num_slots == self.num_experts, "Smart router should not have redundant slots"
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
self.initial_local_expert_ids = self.initial_global_assignments[
@@ -140,8 +142,7 @@ class WideEPMoE(MoE):
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
if self.use_dp:
max_num_tokens *= model_config.mapping.world_size
max_num_tokens *= model_config.mapping.world_size
self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
# The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
if self.moe_max_num_tokens < max_num_tokens:
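The lines above decide whether MoE chunking (and with it the auxiliary CUDA stream and events) is needed. A minimal sketch with assumed numbers for `max_num_tokens`, the world size, and `moe_max_num_tokens`:

```python
# Hypothetical numbers, only to illustrate when chunking kicks in.
max_num_tokens = 2048       # model_config.max_num_tokens (assumed)
world_size = 8              # attention DP: every rank can contribute its own tokens
max_num_tokens *= world_size            # worst case: 16384 tokens reach the MoE layer
moe_max_num_tokens = 8192               # model_config.moe_max_num_tokens (assumed)

needs_chunking = moe_max_num_tokens < max_num_tokens                         # True
num_chunks = (max_num_tokens + moe_max_num_tokens - 1) // moe_max_num_tokens  # ceil -> 2
print(needs_chunking, num_chunks)
```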
@@ -163,7 +164,6 @@ class WideEPMoE(MoE):
16384 * self.num_slots // routing_method.get_experts_per_token(),
)
self.has_been_profiled = False
self.has_been_profiled_min_latency = False
self.alltoall_method_type = self.select_alltoall_method_type(
model_config.mapping, routing_method.experts_per_token, dtype,
@@ -173,14 +173,15 @@ class WideEPMoE(MoE):
key="alltoall_method_type")
self.use_postquant_alltoall = False
if self.enable_alltoall:
assert self.use_dp and self.parallel_size > 1,\
"alltoall should only enabled with attention dp and parallel_size > 1"
qm = self.quant_config.quant_mode
self.use_postquant_alltoall = (os.environ.get(
"TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
== "1") and qm.has_nvfp4()
self.enable_alltoall_without_allgather = os.environ.get(
"TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER", "0") == "1"
# TODO: support alltoall without allgather for top_k % 4 != 0
self.enable_alltoall_without_allgather = (
os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
"1") == "1"
) and self.alltoall_method_type == AlltoallMethodType.MNNVL and routing_method.experts_per_token % 4 == 0
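Both knobs above are environment-gated. A standalone sketch of how they resolve with default settings; the flag names come from this diff, while the surrounding context values are assumed:

```python
import os

# Assumed context, not taken from a real config.
has_nvfp4 = True
alltoall_is_mnnvl = True
experts_per_token = 8

use_postquant_alltoall = (
    os.environ.get("TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1") and has_nvfp4
# Now defaults to on, but only for the MNNVL path and top_k divisible by 4.
enable_alltoall_without_allgather = (
    os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER", "1") == "1"
) and alltoall_is_mnnvl and experts_per_token % 4 == 0

print(use_postquant_alltoall, enable_alltoall_without_allgather)  # True True by default
```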
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
@@ -204,9 +205,6 @@ class WideEPMoE(MoE):
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input
self._weights_created = False
if not model_config.skip_create_weights_in_init:
self.create_weights()
@@ -218,12 +216,6 @@ class WideEPMoE(MoE):
def _check_configs(self):
assert self._weights_created
if self.enable_alltoall:
assert self.use_dp and self.parallel_size > 1,\
"alltoall should only enabled with attention dp and parallel_size > 1"
if self.apply_router_weight_on_input:
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True):
if not (self.quant_config.quant_mode.has_nvfp4()
@@ -323,24 +315,20 @@ class WideEPMoE(MoE):
use_dp_padding: Optional[bool] = None,
):
outputs = inputs
if self.parallel_size > 1 and not self.enable_alltoall:
if self.use_dp:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
outputs = reducescatter(
inputs,
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
elif self.reduce_results:
outputs = self.all_reduce(inputs)
if not self.enable_alltoall:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
outputs = reducescatter(
inputs,
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
return outputs
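A single-process toy illustration of the `sizes` argument above (shapes are assumed): with DP padding every rank contributes the same number of rows, so `sizes=None`; without padding, each rank's true token count is needed so the reduce-scatter can split the gathered rows unevenly.

```python
import torch

# Assumed per-rank token counts when use_dp_padding is False.
all_rank_num_tokens = [3, 5, 2, 4]
hidden_size = 16
gathered = torch.randn(sum(all_rank_num_tokens), hidden_size)

# torch.split mirrors how the uneven `sizes` list carves the row dimension.
per_rank = torch.split(gathered, all_rank_num_tokens, dim=0)
assert [t.shape[0] for t in per_rank] == all_rank_num_tokens
```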
def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
@@ -373,17 +361,12 @@ class WideEPMoE(MoE):
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32
if self.apply_router_weight_on_input:
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
x = x * token_final_scales.to(x.dtype)
# TODO: remove this once we have correct fusedmoe kernel ready
token_final_scales = None
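A toy sketch of the top-1 workaround above (shapes assumed): the single routing scale is folded into the activations and the per-token scales are then dropped, so the fused MoE kernel does not apply them again after FC2.

```python
import torch

num_tokens, hidden_size = 4, 8
x = torch.randn(num_tokens, hidden_size)
token_final_scales = torch.rand(num_tokens, 1, dtype=torch.float32)  # top_k == 1

x = x * token_final_scales.to(x.dtype)   # scale applied on the input...
token_final_scales = None                # ...so no scaling is left for the epilogue
```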
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_first_call:
self.layer_load_balancer.maybe_cudagraph_done_wait()
use_allgather = not self.enable_alltoall
loadbalancer_local_statistic_info = None
gathered_loadbalancer_local_statistic_info = None
token_selected_experts_for_statistic = None
@@ -464,29 +447,37 @@ class WideEPMoE(MoE):
)
x_sf = None
x_is_sf_swizzled = x.is_sf_swizzled if isinstance(
x, Fp4QuantizedTensor) else False
x_row = x.shape[0]
x_col = x.shape[1]
sf_swizzle = True
if self.has_any_quant:
if self.has_fp8_qdq:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant)
elif self.has_nvfp4:
if not disable_fp4_allgather() or self.use_postquant_alltoall:
if use_allgather or self.use_postquant_alltoall:
if isinstance(x, Fp4QuantizedTensor):
if use_allgather:
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
x, x_sf = x.fp4_tensor, x.scaling_factor
x_row = x.shape[0]
# note: we use uint8 to store 2 fp4 values
x_col = x.shape[1] * 2
else:
sf_swizzle = not self.use_postquant_alltoall
# for both postquant alltoall and allgather, we need the non-swizzled layout
needed_sf_swizzle = False
x_row = x.shape[0]
x_col = x.shape[1]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size,
False, sf_swizzle)
x,
self.fc31_input_scale,
self.scaling_vector_size,
sfUseUE8M0=False,
swizzedLayout=needed_sf_swizzle)
if self.use_postquant_alltoall:
x_sf = x_sf.view((x_row, -1))
x_is_sf_swizzled = needed_sf_swizzle
elif self.has_deepseek_fp8_block_scales:
use_deepseek_fp8_block_scale = True
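A back-of-the-envelope sketch of the NVFP4 packing noted above (sizes assumed): each uint8 packs two fp4 values, so the logical hidden size is twice the storage width, and one scaling factor covers `scaling_vector_size` elements per row.

```python
num_tokens = 128
hidden_size = 7168               # logical fp4 elements per token (assumed)
scaling_vector_size = 16

packed_cols = hidden_size // 2                      # uint8 columns, i.e. x.shape[1]
sf_per_row = hidden_size // scaling_vector_size     # scaling factors per token
print(packed_cols, sf_per_row)                      # 3584 448
```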
@@ -498,8 +489,8 @@ class WideEPMoE(MoE):
f"unsupported quantization mode: {self.quant_config.quant_mode}"
)
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
) and not self.enable_alltoall:
if use_allgather:
# allgather case.
if self.enable_dummy_allreduce:
self.dummy_allreduce()
x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
@@ -513,10 +504,11 @@ class WideEPMoE(MoE):
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
x_row = x.shape[0]
# Fp4 gemm has extra scaling factor
if x_sf is not None:
x_sf = reswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
assert not x_is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
@@ -532,42 +524,22 @@ class WideEPMoE(MoE):
is_first_stage=is_first_call,
is_last_stage=is_last_call)
if self.smart_router and not cutlass_min_latency_mode:
ep_size = self.cluster_size
ep_rank = self.cluster_rank
expert_start = ep_rank * self.num_experts // ep_size
expert_end = min(self.num_experts,
(ep_rank + 1) * self.num_experts // ep_size)
w3_w1_weight = self.w3_w1_weight.narrow(0, expert_start,
expert_end - expert_start)
w2_weight = self.w2_weight.narrow(0, expert_start,
expert_end - expert_start)
cluster_size = self.ep_size
cluster_rank = self.ep_rank
quant_scales = self.get_quant_scales(expert_start, expert_end)
else:
ep_size = self.ep_size
ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight
w2_weight = self.w2_weight
cluster_size = self.cluster_size
cluster_rank = self.cluster_rank
quant_scales = self.quant_scales
ep_size = self.ep_size
ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight
w2_weight = self.w2_weight
cluster_size = self.cluster_size
cluster_rank = self.cluster_rank
quant_scales = self.quant_scales
if self.use_postquant_alltoall:
if x_sf is not None and self.has_nvfp4:
assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
x, x_sf = self.alltoall_postquant_dispatch(
x,
x_sf,
x_row,
x_col,
alltoall_info,
is_sf_swizzle=sf_swizzle)
x, x_sf, alltoall_info)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if x_sf is not None:
if self.has_nvfp4 and sf_swizzle:
x_sf = unswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
# Adapter between `x_sf` and DeepEP
# TODO: remove the adapter by adding dtype support to DeepEP
x_sf_dtype = x_sf.dtype
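A toy sketch of the dtype-adapter idea above (the concrete dtype is an assumption, not taken from DeepEP): a tensor whose dtype the transport does not understand is reinterpreted as uint8 for the dispatch and viewed back afterwards.

```python
import torch

# Assume the scaling factors arrive as float8_e4m3fn (illustrative only).
x_sf = torch.zeros(4, 16, dtype=torch.uint8).view(torch.float8_e4m3fn)
x_sf_dtype = x_sf.dtype

payload = x_sf.view(torch.uint8)     # what a dtype-agnostic dispatch would carry
restored = payload.view(x_sf_dtype)  # view back after communication
assert restored.dtype == x_sf_dtype
```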
@@ -631,7 +603,7 @@ class WideEPMoE(MoE):
enable_alltoall=self.enable_alltoall,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4a8_group_scaling=use_w4a8_group_scaling,
min_latency_mode=cutlass_min_latency_mode,
min_latency_mode=False,
tune_max_num_tokens=self.tune_max_num_tokens,
)
@@ -639,14 +611,9 @@ class WideEPMoE(MoE):
) and is_last_call:
self.layer_load_balancer.set_cpu_stage()
if cutlass_min_latency_mode:
assert not self.reduce_results
assert not self.enable_alltoall
else:
# Custom op requires all inputs are in the same type.
# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.
final_hidden_states = final_hidden_states[0]
# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.
final_hidden_states = final_hidden_states[0]
if self.enable_alltoall:
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -702,23 +669,13 @@ class WideEPMoE(MoE):
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
if self.use_dp:
assert all_rank_num_tokens is not None
assert use_dp_padding is not None
num_rows = sum(all_rank_num_tokens)
else:
num_rows = x.shape[0]
assert all_rank_num_tokens is not None
assert use_dp_padding is not None
num_rows = sum(all_rank_num_tokens)
# in case num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
num_chunks = (num_rows + self.moe_max_num_tokens -
1) // self.moe_max_num_tokens
# TODO: remove cutlass_min_latency_mode since it is not used anymore
cutlass_min_latency_mode = not do_finalize
if cutlass_min_latency_mode:
assert num_chunks == 1 and (
not self.reduce_results
), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False"
if use_dp_padding:
all_rank_num_tokens_padded = [all_rank_max_num_tokens
@@ -731,7 +688,6 @@ class WideEPMoE(MoE):
outputs = self.forward_chunk(
x,
router_logits,
cutlass_min_latency_mode,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
all_rank_max_num_tokens=all_rank_max_num_tokens,
@@ -750,29 +706,24 @@ class WideEPMoE(MoE):
split_num_chunks - val_mod)
return split_chunk_size_list
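The body of `split_chunk` is only partially visible in this hunk; a self-contained sketch of what it is assumed to compute, dividing `val` tokens into `num_chunks` nearly equal parts with the larger parts first:

```python
def split_chunk_sketch(val, num_chunks):
    # Hypothetical stand-in for split_chunk, based on the visible tail of the helper.
    val_div, val_mod = divmod(val, num_chunks)
    return [val_div + 1] * val_mod + [val_div] * (num_chunks - val_mod)

assert split_chunk_sketch(10, 4) == [3, 3, 2, 2]
assert sum(split_chunk_sketch(10, 4)) == 10
```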
if self.use_dp:
all_rank_chunk_size_list = [
split_chunk(val, num_chunks)
for val in all_rank_num_tokens_padded
]
all_rank_chunk_size_list = [
split_chunk(val, num_chunks)
for val in all_rank_num_tokens_padded
]
all_rank_num_tokens_list = [[
val[idx_chunk] for val in all_rank_chunk_size_list
] for idx_chunk in range(num_chunks)]
all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
num_chunks)
chunk_size_list = all_rank_chunk_size_list[self.rank]
if self.enable_alltoall:
all_rank_num_tokens_list = [[
val[idx_chunk] for val in all_rank_chunk_size_list
] for idx_chunk in range(num_chunks)]
all_rank_max_num_tokens_list = split_chunk(
all_rank_max_num_tokens, num_chunks)
chunk_size_list = all_rank_chunk_size_list[self.rank]
if self.enable_alltoall:
all_rank_num_tokens_list = [[
1 if val == 0 else val for val in val_list
] for val_list in all_rank_num_tokens_list]
all_rank_max_num_tokens_list = [
1 if val == 0 else val
for val in all_rank_max_num_tokens_list
]
else:
all_rank_num_tokens_list = [None] * num_chunks
all_rank_max_num_tokens_list = [None] * num_chunks
chunk_size_list = split_chunk(x.shape[0], num_chunks)
1 if val == 0 else val for val in val_list
] for val_list in all_rank_num_tokens_list]
all_rank_max_num_tokens_list = [
1 if val == 0 else val
for val in all_rank_max_num_tokens_list
]
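A toy illustration (counts assumed) of the clamping above: with alltoall enabled, a rank that has no tokens in some chunk still needs a non-empty slot, so zero counts are bumped to one.

```python
# Per-chunk, per-rank token counts (assumed).
all_rank_num_tokens_list = [[4, 3, 2], [3, 0, 1]]
all_rank_max_num_tokens_list = [max(chunk) for chunk in all_rank_num_tokens_list]

all_rank_num_tokens_list = [[1 if val == 0 else val for val in val_list]
                            for val_list in all_rank_num_tokens_list]
all_rank_max_num_tokens_list = [1 if val == 0 else val
                                for val in all_rank_max_num_tokens_list]
print(all_rank_num_tokens_list)       # [[4, 3, 2], [3, 1, 1]]
print(all_rank_max_num_tokens_list)   # [4, 3]
```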
x_list = x.split(chunk_size_list)
router_logits_list = router_logits.split(chunk_size_list)
@@ -795,10 +746,9 @@ class WideEPMoE(MoE):
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
idx_chunk],
all_rank_max_num_tokens=
all_rank_max_num_tokens_list[idx_chunk]
if self.use_dp else None,
all_rank_max_num_tokens_list[idx_chunk],
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
if idx_chunk > 0:
@@ -812,9 +762,9 @@ class WideEPMoE(MoE):
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
idx_chunk],
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk] if self.use_dp else None,
idx_chunk],
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
with torch.cuda.stream(self.aux_stream):
@@ -827,10 +777,9 @@ class WideEPMoE(MoE):
outputs = self.forward_chunk(
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
if self.use_dp else None,
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk] if self.use_dp else None,
idx_chunk],
repeating_info=(is_first_call, is_last_call))
outputs_list.append(outputs)
@@ -850,9 +799,8 @@ class WideEPMoE(MoE):
self.event_dict[EventType.MoeChunkingOverlap].record()
self.event_dict[EventType.MoeChunkingOverlap].wait()
outputs = torch.cat(outputs_list)
if self.use_dp:
rank = self.mapping.tp_rank
outputs = outputs[:all_rank_num_tokens[rank]]
rank = self.mapping.tp_rank
outputs = outputs[:all_rank_num_tokens[rank]]
self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
return outputs
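A small sketch (numbers assumed) of the final trim above: with DP padding, every rank's chunks were padded to the per-chunk maximum, so the concatenated output is cut back to this rank's true token count.

```python
import torch

all_rank_num_tokens = [5, 7, 6]     # true tokens per attention-DP rank (assumed)
rank = 0
hidden_size = 16

# Two chunks padded to the per-chunk maximum (4 + 4 rows), then concatenated.
outputs = torch.cat([torch.randn(4, hidden_size), torch.randn(4, hidden_size)])
outputs = outputs[:all_rank_num_tokens[rank]]
assert outputs.shape[0] == 5
```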
@@ -862,8 +810,8 @@ class WideEPMoE(MoE):
token_final_scales: torch.Tensor,
local_statistic_tensor: Optional[torch.Tensor]):
top_k = self.routing_method.experts_per_token
# TODO: support alltoall without allgather for top_k % 4 != 0
if self.enable_alltoall_without_allgather and top_k % 4 == 0:
if self.enable_alltoall_without_allgather:
alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_slots, token_final_scales,
local_statistic_tensor, self.alltoall_prepare_workspace,
@@ -913,22 +861,13 @@ class WideEPMoE(MoE):
return x, token_selected_slots, token_final_scales, gathered_local_statistic_tensor, alltoall_info
def alltoall_postquant_dispatch(self,
x: torch.Tensor,
x_sf: torch.Tensor,
x_row: int,
x_col: int,
alltoall_info: MoEAlltoallInfo,
is_sf_swizzle: bool = True):
def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
alltoall_info: MoEAlltoallInfo):
x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info,
self.alltoall_workspace, self.ep_rank,
self.ep_size)
if x_sf is not None:
if self.has_nvfp4 and is_sf_swizzle:
x_sf = unswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info,
self.alltoall_workspace,
self.ep_rank, self.ep_size)