feat: large-scale EP(part 3 - refactor: FusedMoe for redundant expert) (#4495)
refactor fused_moe for redundant expert

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
This commit is contained in:
parent 407ef08662
commit 4018806742
@@ -272,6 +272,10 @@ class FusedMoE(nn.Module):
     In min-latency mode, setting `reduce_results=False` disables the AllReduce in the FusedMoE module, so any necessary AllReduce operations must be added explicitly in the model definition.

     AttentionDP should be turned off for min-latency mode.
+
+    When we have redundant experts, we have more weight slots than `num_experts`; in that case, we separate the concepts of expert and slot.
+    Expert is a concept from the model's perspective, while slot is a concept from the model engine's perspective.
+    There should be at least `num_experts` slots in the model engine. More than that is OK; in that case, some experts may have multiple replicas.
     """

     def __init__(
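The expert/slot distinction described in the new docstring can be pictured with a small standalone sketch. This is not part of the diff; the toy sizes and the slot_to_expert list are invented purely for illustration:

# Toy illustration: 4 experts spread over 6 slots (2 redundant replicas).
num_experts = 4
num_slots = 6

slot_to_expert = [slot % num_experts for slot in range(num_slots)]
# [0, 1, 2, 3, 0, 1] -> every expert has at least one slot;
#                       experts 0 and 1 each have two replicas.
assert len(set(slot_to_expert)) == num_experts  # at least num_experts slots are required

The model keeps routing over expert ids; only the engine decides which physical slot (and therefore which replica) serves a given expert.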
@@ -326,11 +330,25 @@ class FusedMoE(nn.Module):

         self.intermediate_size_per_partition = intermediate_size // self.tp_size

-        self.expert_size_per_partition = num_experts // self.ep_size
-        self.expert_start = self.ep_rank * self.expert_size_per_partition
-        self.expert_end = min(
-            self.expert_start + self.expert_size_per_partition,
-            self.num_experts)
+        # self.expert_slots_per_partition will be replaced with real slots_per_partition to enable redundant expert slots
+        self.expert_slots_per_partition = num_experts // self.ep_size
+        assert self.expert_slots_per_partition * self.ep_size >= num_experts, "total slots should be at least num_experts"
+        if self.smart_router:
+            assert self.expert_slots_per_partition == num_experts // self.ep_size,\
+                "Smart router should not have redundant slots"
+        self.num_slots = self.expert_slots_per_partition * self.ep_size
+        # Here the meaning of expert_size_per_partition is the number of expert slots that each rank has.
+        self.expert_size_per_partition = self.expert_slots_per_partition
+        self.slot_start = self.ep_rank * self.expert_size_per_partition
+        self.slot_end = self.slot_start + self.expert_size_per_partition
+
+        self.initial_global_assignments = [
+            (ep_rank * self.num_experts // self.ep_size + local_slot_id) %
+            self.num_experts for ep_rank in range(self.ep_size)
+            for local_slot_id in range(self.expert_slots_per_partition)
+        ]
+        self.initial_local_expert_ids = self.initial_global_assignments[
+            self.slot_start:self.slot_end]

         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
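For intuition, the initial_global_assignments comprehension above can be evaluated standalone with toy numbers. This sketch is not part of the diff; expert_slots_per_partition is overridden by hand here (in this commit it is still fixed to num_experts // ep_size, so there is no redundancy yet), purely to show what a redundant layout would look like:

num_experts, ep_size = 8, 4
expert_slots_per_partition = 3                       # hypothetical: 1 redundant slot per rank
num_slots = expert_slots_per_partition * ep_size     # 12 slots for 8 experts

initial_global_assignments = [
    (ep_rank * num_experts // ep_size + local_slot_id) % num_experts
    for ep_rank in range(ep_size)
    for local_slot_id in range(expert_slots_per_partition)
]
# [0, 1, 2,  2, 3, 4,  4, 5, 6,  6, 7, 0]
# every expert appears at least once; experts 0, 2, 4, 6 have two replicas

ep_rank = 1
slot_start = ep_rank * expert_slots_per_partition    # 3
slot_end = slot_start + expert_slots_per_partition   # 6
initial_local_expert_ids = initial_global_assignments[slot_start:slot_end]  # [2, 3, 4]

Each rank starts from a contiguous block of experts and wraps around modulo num_experts, so the extra slots simply replicate the experts that follow the block.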
@@ -354,7 +372,7 @@ class FusedMoE(nn.Module):
         # around 16k tokens per expert, which is well into the compute bound domain.
         self.tune_max_num_tokens = min(
             self.moe_max_num_tokens,
-            16384 * num_experts // routing_method.get_experts_per_token(),
+            16384 * self.num_slots // routing_method.get_experts_per_token(),
         )
         self.has_been_profiled = False
         self.has_been_profiled_min_latency = False
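The switch from num_experts to num_slots in the tuning bound is plain arithmetic; a hypothetical example (not from the diff, all numbers invented for illustration):

# Hypothetical config with redundancy: 256 experts, 8 experts per token,
# plus 32 redundant slots, giving 288 slots in total.
num_experts, num_slots, experts_per_token = 256, 288, 8
moe_max_num_tokens = 1_000_000  # assumed large enough not to be the binding term

old_bound = 16384 * num_experts // experts_per_token   # 524_288
new_bound = 16384 * num_slots // experts_per_token     # 589_824
tune_max_num_tokens = min(moe_max_num_tokens, new_bound)
# the profiling budget now scales with slots, so redundant slots are also
# covered at roughly 16k tokens per slot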
@@ -453,45 +471,45 @@ class FusedMoE(nn.Module):
     def is_cutlass(self):
         return not self.is_trtllm()

-    def get_quant_scales(self, expert_start, expert_end):
+    def get_quant_scales(self, slot_start, slot_end):
         assert self.smart_router

         if self.has_fp8_block_scales:
             return FusedMoEQuantScalesFP8BlockScales(
                 fc_weight_scales=self.w3_w1_weight_scaling_factor.narrow(
-                    0, expert_start, expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
                 proj_weight_scales=self.w2_weight_scaling_factor.narrow(
-                    0, expert_start, expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
             )
         elif self.has_nvfp4:
             return FusedMoEQuantScalesNVFP4(
                 fc1_act_global=self.fc31_input_scale,
                 fc1_weight_block=self.w3_w1_weight_scale.narrow(
-                    0, expert_start, expert_end - expert_start),
-                fc1_global=self.fc31_alpha.narrow(0, expert_start,
-                                                  expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
+                fc1_global=self.fc31_alpha.narrow(0, slot_start,
+                                                  slot_end - slot_start),
                 fc2_act_global=self.fc2_input_scale,
                 fc2_weight_block=self.w2_weight_scale.narrow(
-                    0, expert_start, expert_end - expert_start),
-                fc2_global=self.fc2_alpha.narrow(0, expert_start,
-                                                 expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
+                fc2_global=self.fc2_alpha.narrow(0, slot_start,
+                                                 slot_end - slot_start),
             )
         elif self.has_w4afp8:
             return FusedMoEQuantScalesW4A8(
                 scale_1_interleaved=self.fc31_weight_scale.narrow(
-                    0, expert_start, expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
                 scale_2_interleaved=self.fc2_weight_scale.narrow(
-                    0, expert_start, expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
                 pre_quant_scale_1=self.fc31_act_scale.narrow(
-                    0, expert_start, expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
                 pre_quant_scale_2=self.fc2_act_scale.narrow(
-                    0, expert_start, expert_end - expert_start),
+                    0, slot_start, slot_end - slot_start),
                 zero_1=torch.Tensor(),
                 zero_2=torch.Tensor(),
-                alpha_1=self.fc31_alpha.narrow(0, expert_start,
-                                               expert_end - expert_start),
-                alpha_2=self.fc2_alpha.narrow(0, expert_start,
-                                              expert_end - expert_start),
+                alpha_1=self.fc31_alpha.narrow(0, slot_start,
+                                               slot_end - slot_start),
+                alpha_2=self.fc2_alpha.narrow(0, slot_start,
+                                              slot_end - slot_start),
             )
         else:
             return self.quant_scales
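The narrow calls above just take this rank's contiguous slot range along dim 0 of each scale tensor; only the index names change from expert indices to slot indices. A minimal standalone sketch, with tensor shapes invented for illustration:

import torch

num_slots, hidden = 12, 16
w2_weight_scaling_factor = torch.rand(num_slots, hidden)  # one row of scales per slot

slot_start, slot_end = 3, 6  # this rank owns slots [3, 6)
local_scales = w2_weight_scaling_factor.narrow(0, slot_start, slot_end - slot_start)
assert local_scales.shape == (3, hidden)  # scales for the 3 local slots
# narrow returns a view, so building the per-rank quant scales does not copy data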
@@ -1081,13 +1099,13 @@ class FusedMoE(nn.Module):
                 self.w3_w1_weight_scaling_factor,
                 self.w2_weight,
                 self.w2_weight_scaling_factor,
-                self.num_experts,
+                self.num_slots,
                 top_k,
                 n_group,
                 topk_group,
                 self.intermediate_size_per_partition,
                 self.
-                expert_start,  # local_expert_start; use ep_rank if stride!=1
+                slot_start,  # local_expert_start; use ep_rank if stride!=1
                 self.expert_size_per_partition,  # local_expert_size
                 routed_scaling_factor,
             )
@@ -1110,13 +1128,13 @@ class FusedMoE(nn.Module):
                 self.fc31_scale_c.data,
                 self.fc31_alpha.data,
                 self.fc2_alpha.data,
-                self.num_experts,
+                self.num_slots,
                 top_k,
                 n_group,
                 topk_group,
                 self.intermediate_size_per_partition,
                 self.
-                expert_start,  # local_expert_start; use ep_rank if stride!=1
+                slot_start,  # local_expert_start; use ep_rank if stride!=1
                 self.expert_size_per_partition,  # local_expert_size
                 routed_scaling_factor,
             )
@@ -1282,8 +1300,10 @@ class FusedMoE(nn.Module):
        # CPU memory bandwidth better.
        threads = []

-        for expert_id in range(self.expert_start, self.expert_end):
-            expert_idx = expert_id - self.expert_start
+        for local_slot_id, expert_id in enumerate(
+                self.initial_local_expert_ids):
+            # expert_idx is the local slot index of current rank
+            expert_idx = local_slot_id

            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w1_weight = weights[f"{expert_id}.w1.weight"]
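The key change in the weight-loading loop is that the destination index is now the local slot, while the checkpoint is still indexed by expert id, so a replicated expert is simply loaded into more than one slot. A schematic sketch, not part of the diff (the weights dict and the destination list are stand-ins, not the real loading code):

# Stand-in for the checkpoint: one entry per *expert* id.
weights = {f"{expert_id}.w1.weight": f"<w1 of expert {expert_id}>" for expert_id in range(8)}

initial_local_expert_ids = [2, 3, 4]   # e.g. rank 1 from the toy assignment above
local_w1 = [None] * len(initial_local_expert_ids)

for local_slot_id, expert_id in enumerate(initial_local_expert_ids):
    expert_idx = local_slot_id            # destination: local slot index on this rank
    local_w1[expert_idx] = weights[f"{expert_id}.w1.weight"]
# slots are filled in order; if an expert id appeared twice, its weights would simply
# be loaded into two different slots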
@@ -1342,7 +1362,7 @@ class FusedMoE(nn.Module):
                load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
                                  self.tp_size, self.tp_rank,
                                  TensorParallelMode.ROW)
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]

            w2_scales = torch.stack(all_w2_scales)
@@ -1352,14 +1372,14 @@ class FusedMoE(nn.Module):
                load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
                                  self.tp_size, self.tp_rank,
                                  TensorParallelMode.COLUMN)
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]

            all_w1_scales = [
                load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
                                  self.tp_size, self.tp_rank,
                                  TensorParallelMode.COLUMN)
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]

            w3_w1_scales = torch.cat(
@@ -1458,7 +1478,8 @@ class FusedMoE(nn.Module):
                                             dst_w2_weight_scale: torch.Tensor):
            dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([]))

-        for expert_id in range(self.expert_start, self.expert_end):
+        for local_slot_id, expert_id in enumerate(
+                self.initial_local_expert_ids):
            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
@@ -1472,7 +1493,7 @@ class FusedMoE(nn.Module):
                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
                )

-            expert_idx = expert_id - self.expert_start
+            expert_idx = local_slot_id

            load_expert_w3_w1_weight_scale_fp8_qdq(
                w1_weight_scale, w3_weight_scale,
@@ -1617,7 +1638,8 @@ class FusedMoE(nn.Module):
            dst_w2_alpha.copy_(1.0 /
                               (final_fc2_input_scale * w2_weight_scale_2))

-        for expert_id in range(self.expert_start, self.expert_end):
+        for local_slot_id, expert_id in enumerate(
+                self.initial_local_expert_ids):
            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
@@ -1640,7 +1662,7 @@ class FusedMoE(nn.Module):
                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
                )

-            expert_idx = expert_id - self.expert_start
+            expert_idx = local_slot_id

            load_expert_w3_w1_weight_scale_nvfp4(
                w1_weight_scale, w3_weight_scale,
@@ -1664,11 +1686,11 @@ class FusedMoE(nn.Module):
            assert (len(self.interleave) == 2)
            all_w3_input_scales = [
                load_weight_shard(weights[f"{expert_id}.w3.input_scale"])
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]
            all_w1_input_scales = [
                load_weight_shard(weights[f"{expert_id}.w1.input_scale"])
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]
            all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales),
                                               torch.stack(all_w1_input_scales))
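Since w1 and w3 feed one fused GEMM, their input scales are combined with an elementwise max across the experts held by this rank's slots. A rough standalone sketch (toy tensors, not the real checkpoint values):

import torch

initial_local_expert_ids = [2, 3, 4]  # experts loaded into this rank's slots
all_w3_input_scales = [torch.tensor(0.1 * (e + 1)) for e in initial_local_expert_ids]
all_w1_input_scales = [torch.tensor(0.05 * (e + 2)) for e in initial_local_expert_ids]

all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales),
                                   torch.stack(all_w1_input_scales))
# one fused scale per slot: max(w3_scale, w1_scale) of the expert in that slot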
@@ -1681,13 +1703,13 @@ class FusedMoE(nn.Module):
                load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
                                  self.tp_size, self.tp_rank,
                                  TensorParallelMode.COLUMN)
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]
            all_w1_scales = [
                load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
                                  self.tp_size, self.tp_rank,
                                  TensorParallelMode.COLUMN)
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]
            all_w3_w1_scales = torch.cat(
                [torch.stack(all_w3_scales),
@@ -1709,7 +1731,7 @@ class FusedMoE(nn.Module):
            # fc2 scales
            all_w2_input_scales = [
                load_weight_shard(weights[f"{expert_id}.w2.input_scale"])
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]
            all_w2_input_scales = torch.stack(all_w2_input_scales).to(self.dtype)
            all_w2_input_scales = torch.ones_like(
@@ -1721,7 +1743,7 @@ class FusedMoE(nn.Module):
                load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
                                  self.tp_size, self.tp_rank,
                                  TensorParallelMode.ROW)
-                for expert_id in range(self.expert_start, self.expert_end)
+                for expert_id in self.initial_local_expert_ids
            ]
            if self.sm_version == 89:
                w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(