feat: large-scale EP(part 3 - refactor: FusedMoe for redundant expert) (#4495)

refactor fused_moe for redundant experts

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
dongxuy04 2025-05-21 17:17:49 +08:00 committed by GitHub
parent 407ef08662
commit 4018806742

@@ -272,6 +272,10 @@ class FusedMoE(nn.Module):
In min-latency mode, setting `reduce_results=False` disables the AllReduce in the FusedMoE module, so any necessary AllReduce operations must be added explicitly in the model definition.
AttentionDP should be turned off for min-latency mode.
When there are redundant experts, there are more weight slots than `num_experts`; in that case, we separate the concepts of expert and slot.
An expert is a concept from the model's perspective, while a slot is a concept from the model engine's perspective.
There should be at least `num_experts` slots in the model engine. Having more is fine; in that case, some experts may have multiple replicas.
"""
def __init__(
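
To make the expert/slot distinction in the docstring concrete: the model defines `num_experts` logical experts, while the engine allocates at least that many physical weight slots, so some experts may own more than one slot. A minimal illustrative sketch follows; the numbers are assumptions chosen for illustration, not values taken from this change.

# Hypothetical example: 4 logical experts served by 6 physical slots.
num_experts = 4                      # experts, from the model's perspective
slot_to_expert = [0, 1, 2, 2, 3, 0]  # slots, from the model engine's perspective
assert len(slot_to_expert) >= num_experts              # at least num_experts slots
assert set(slot_to_expert) == set(range(num_experts))  # every expert has >= 1 replica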
@@ -326,11 +330,25 @@ class FusedMoE(nn.Module):
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.expert_size_per_partition = num_experts // self.ep_size
self.expert_start = self.ep_rank * self.expert_size_per_partition
self.expert_end = min(
self.expert_start + self.expert_size_per_partition,
self.num_experts)
# self.expert_slots_per_partition will later be set to the real number of slots per partition to enable redundant expert slots
self.expert_slots_per_partition = num_experts // self.ep_size
assert self.expert_slots_per_partition * self.ep_size >= num_experts, "total slots should be at least num_experts"
if self.smart_router:
assert self.expert_slots_per_partition == num_experts // self.ep_size,\
"Smart router should not have redundant slots"
self.num_slots = self.expert_slots_per_partition * self.ep_size
# From here on, expert_size_per_partition means the number of expert slots that each rank owns.
self.expert_size_per_partition = self.expert_slots_per_partition
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
self.initial_global_assignments = [
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
self.num_experts for ep_rank in range(self.ep_size)
for local_slot_id in range(self.expert_slots_per_partition)
]
self.initial_local_expert_ids = self.initial_global_assignments[
self.slot_start:self.slot_end]
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE is multiplied by DP size when attention DP is enabled
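
As a sanity check on the `initial_global_assignments` expression in the hunk above, the snippet below reproduces the slot-to-expert layout it generates; the parameter values are assumptions chosen only to make the output easy to read.

# Assumed example: 4 experts, ep_size 2, 3 slots per partition (6 slots total),
# so two of the four experts receive a second replica.
num_experts, ep_size, expert_slots_per_partition = 4, 2, 3

initial_global_assignments = [
    (ep_rank * num_experts // ep_size + local_slot_id) % num_experts
    for ep_rank in range(ep_size)
    for local_slot_id in range(expert_slots_per_partition)
]
print(initial_global_assignments)  # [0, 1, 2, 2, 3, 0]

# Each rank keeps the contiguous slice of slots it owns, e.g. for ep_rank 1:
slot_start = 1 * expert_slots_per_partition
print(initial_global_assignments[slot_start:slot_start + expert_slots_per_partition])  # [2, 3, 0]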
@@ -354,7 +372,7 @@ class FusedMoE(nn.Module):
# around 16k tokens per expert, which is well into the compute bound domain.
self.tune_max_num_tokens = min(
self.moe_max_num_tokens,
16384 * num_experts // routing_method.get_experts_per_token(),
16384 * self.num_slots // routing_method.get_experts_per_token(),
)
self.has_been_profiled = False
self.has_been_profiled_min_latency = False
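
With the change above, the autotuning cap scales with the total number of slots rather than the number of logical experts. A quick arithmetic check with assumed values (6 slots, top-2 routing, and a `moe_max_num_tokens` large enough that the second term binds):

# Assumed illustration of the tune_max_num_tokens cap.
moe_max_num_tokens = 1 << 20
num_slots, experts_per_token = 6, 2
tune_max_num_tokens = min(moe_max_num_tokens, 16384 * num_slots // experts_per_token)
print(tune_max_num_tokens)  # 49152 (vs. 32768 if only the 4 logical experts were counted)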
@@ -453,45 +471,45 @@ class FusedMoE(nn.Module):
def is_cutlass(self):
return not self.is_trtllm()
def get_quant_scales(self, expert_start, expert_end):
def get_quant_scales(self, slot_start, slot_end):
assert self.smart_router
if self.has_fp8_block_scales:
return FusedMoEQuantScalesFP8BlockScales(
fc_weight_scales=self.w3_w1_weight_scaling_factor.narrow(
0, expert_start, expert_end - expert_start),
0, slot_start, slot_end - slot_start),
proj_weight_scales=self.w2_weight_scaling_factor.narrow(
0, expert_start, expert_end - expert_start),
0, slot_start, slot_end - slot_start),
)
elif self.has_nvfp4:
return FusedMoEQuantScalesNVFP4(
fc1_act_global=self.fc31_input_scale,
fc1_weight_block=self.w3_w1_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
fc1_global=self.fc31_alpha.narrow(0, expert_start,
expert_end - expert_start),
0, slot_start, slot_end - slot_start),
fc1_global=self.fc31_alpha.narrow(0, slot_start,
slot_end - slot_start),
fc2_act_global=self.fc2_input_scale,
fc2_weight_block=self.w2_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
fc2_global=self.fc2_alpha.narrow(0, expert_start,
expert_end - expert_start),
0, slot_start, slot_end - slot_start),
fc2_global=self.fc2_alpha.narrow(0, slot_start,
slot_end - slot_start),
)
elif self.has_w4afp8:
return FusedMoEQuantScalesW4A8(
scale_1_interleaved=self.fc31_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
0, slot_start, slot_end - slot_start),
scale_2_interleaved=self.fc2_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
0, slot_start, slot_end - slot_start),
pre_quant_scale_1=self.fc31_act_scale.narrow(
0, expert_start, expert_end - expert_start),
0, slot_start, slot_end - slot_start),
pre_quant_scale_2=self.fc2_act_scale.narrow(
0, expert_start, expert_end - expert_start),
0, slot_start, slot_end - slot_start),
zero_1=torch.Tensor(),
zero_2=torch.Tensor(),
alpha_1=self.fc31_alpha.narrow(0, expert_start,
expert_end - expert_start),
alpha_2=self.fc2_alpha.narrow(0, expert_start,
expert_end - expert_start),
alpha_1=self.fc31_alpha.narrow(0, slot_start,
slot_end - slot_start),
alpha_2=self.fc2_alpha.narrow(0, slot_start,
slot_end - slot_start),
)
else:
return self.quant_scales
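
The renaming in `get_quant_scales` does not change the slicing mechanics: each per-slot scale tensor is still cut along dimension 0 with `narrow`, only now over this rank's slot range. A minimal sketch with assumed toy shapes (the real scale tensors have quantization-specific layouts):

import torch

# Assumed: 6 slots in total, this rank owns slots [3, 6).
num_slots, slot_start, slot_end = 6, 3, 6
w2_weight_scaling_factor = torch.randn(num_slots, 128)

# narrow(dim, start, length) returns a view covering only this rank's slots.
local_scales = w2_weight_scaling_factor.narrow(0, slot_start, slot_end - slot_start)
assert local_scales.shape == (3, 128)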
@@ -1081,13 +1099,13 @@ class FusedMoE(nn.Module):
self.w3_w1_weight_scaling_factor,
self.w2_weight,
self.w2_weight_scaling_factor,
self.num_experts,
self.num_slots,
top_k,
n_group,
topk_group,
self.intermediate_size_per_partition,
self.
expert_start, # local_expert_start; use ep_rank if stride!=1
slot_start, # local_expert_start; use ep_rank if stride!=1
self.expert_size_per_partition, # local_expert_size
routed_scaling_factor,
)
@@ -1110,13 +1128,13 @@ class FusedMoE(nn.Module):
self.fc31_scale_c.data,
self.fc31_alpha.data,
self.fc2_alpha.data,
self.num_experts,
self.num_slots,
top_k,
n_group,
topk_group,
self.intermediate_size_per_partition,
self.
expert_start, # local_expert_start; use ep_rank if stride!=1
slot_start, # local_expert_start; use ep_rank if stride!=1
self.expert_size_per_partition, # local_expert_size
routed_scaling_factor,
)
@@ -1282,8 +1300,10 @@ class FusedMoE(nn.Module):
# CPU memory bandwidth better.
threads = []
for expert_id in range(self.expert_start, self.expert_end):
expert_idx = expert_id - self.expert_start
for local_slot_id, expert_id in enumerate(
self.initial_local_expert_ids):
# expert_idx is the local slot index on the current rank
expert_idx = local_slot_id
if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight = weights[f"{expert_id}.w1.weight"]
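
With redundant slots, the loading loop iterates over `initial_local_expert_ids` instead of a contiguous expert range, so the same expert id may be loaded into several slots (possibly on different ranks), and `expert_idx` now denotes the destination slot rather than `expert_id - expert_start`. A hypothetical mini-example of that indexing; the checkpoint keys follow the same naming pattern, but the values are placeholders:

# Assumed: this rank owns 3 slots mapped to experts [2, 3, 0].
initial_local_expert_ids = [2, 3, 0]
weights = {f"{e}.w1.weight": f"<w1 tensor of expert {e}>" for e in range(4)}

for local_slot_id, expert_id in enumerate(initial_local_expert_ids):
    expert_idx = local_slot_id                      # destination slot on this rank
    w1_weight = weights[f"{expert_id}.w1.weight"]   # source weights, keyed by expert id
    print(f"slot {expert_idx} <- expert {expert_id}: {w1_weight}")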
@@ -1342,7 +1362,7 @@ class FusedMoE(nn.Module):
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.ROW)
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
w2_scales = torch.stack(all_w2_scales)
@@ -1352,14 +1372,14 @@ class FusedMoE(nn.Module):
load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
w3_w1_scales = torch.cat(
@@ -1458,7 +1478,8 @@ class FusedMoE(nn.Module):
dst_w2_weight_scale: torch.Tensor):
dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([]))
for expert_id in range(self.expert_start, self.expert_end):
for local_slot_id, expert_id in enumerate(
self.initial_local_expert_ids):
if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
@@ -1472,7 +1493,7 @@ class FusedMoE(nn.Module):
f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
)
expert_idx = expert_id - self.expert_start
expert_idx = local_slot_id
load_expert_w3_w1_weight_scale_fp8_qdq(
w1_weight_scale, w3_weight_scale,
@@ -1617,7 +1638,8 @@ class FusedMoE(nn.Module):
dst_w2_alpha.copy_(1.0 /
(final_fc2_input_scale * w2_weight_scale_2))
for expert_id in range(self.expert_start, self.expert_end):
for local_slot_id, expert_id in enumerate(
self.initial_local_expert_ids):
if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
@@ -1640,7 +1662,7 @@ class FusedMoE(nn.Module):
f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
)
expert_idx = expert_id - self.expert_start
expert_idx = local_slot_id
load_expert_w3_w1_weight_scale_nvfp4(
w1_weight_scale, w3_weight_scale,
@@ -1664,11 +1686,11 @@ class FusedMoE(nn.Module):
assert (len(self.interleave) == 2)
all_w3_input_scales = [
load_weight_shard(weights[f"{expert_id}.w3.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
all_w1_input_scales = [
load_weight_shard(weights[f"{expert_id}.w1.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales),
torch.stack(all_w1_input_scales))
@@ -1681,13 +1703,13 @@ class FusedMoE(nn.Module):
load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
@@ -1709,7 +1731,7 @@ class FusedMoE(nn.Module):
# fc2 scales
all_w2_input_scales = [
load_weight_shard(weights[f"{expert_id}.w2.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
all_w2_input_scales = torch.stack(all_w2_input_scales).to(self.dtype)
all_w2_input_scales = torch.ones_like(
@@ -1721,7 +1743,7 @@ class FusedMoE(nn.Module):
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.ROW)
for expert_id in range(self.expert_start, self.expert_end)
for expert_id in self.initial_local_expert_ids
]
if self.sm_version == 89:
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(