[Feature] Support EPLB for DeepSeek v4 Mega Moe (#43339)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao (Engrg-Hardware 1) <weizha@login-lyris01.lyris.clusters.nvidia.com>
This commit is contained in:
Wei Zhao
2026-06-02 13:56:44 -04:00
committed by GitHub
parent fe32e7830b
commit 2427094152
4 changed files with 232 additions and 46 deletions
+16 -7
View File
@@ -61,25 +61,31 @@ class CpuGpuEvent:
self._recorded.set()
def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
def override_envs_for_eplb(
parallel_config: ParallelConfig,
moe_backend: str | None = None,
) -> None:
"""
Override environment variables for EPLB when specific conditions are met.
Args:
parallel_config: The parallel configuration object.
moe_backend: The configured MoE backend (e.g. ``deep_gemm_mega_moe``).
"""
is_data_parallel = parallel_config.data_parallel_size > 1
is_eplb_enabled = parallel_config.enable_eplb
async_eplb = parallel_config.eplb_config.use_async
is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency"
is_mega_moe = moe_backend == "deep_gemm_mega_moe"
is_nccl_based_eplb_communicator = parallel_config.eplb_config.communicator in (
"torch_nccl",
"pynccl",
)
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
# DeepEP low-latency backend.
# Override NCCL_MAX_CTAS to avoid hangs when EPLB's NCCL weight exchange
# contends with MoE backend's cooperative-launch on GPU SMs.
#
# DeepEP low-latency:
# The hang happens when two ranks interleave kernel launches differently
# between NCCL collectives (used by async EPLB weight exchange) and DeepEP
# low-latency (LL) kernels. DeepEP LL uses a cooperative launch and tries
@@ -94,12 +100,14 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
# cooperative kernel to launch and complete, breaking the deadlock.
# See: https://github.com/deepseek-ai/DeepEP/issues/496
#
# DeepGEMM Mega MoE also uses cooperative launch and will cause hang even
# with sync EPLB.
if (
is_data_parallel
and is_eplb_enabled
and is_deepep_ll
and async_eplb
and is_nccl_based_eplb_communicator
and ((is_deepep_ll and async_eplb) or is_mega_moe)
):
current_value_str = os.getenv("NCCL_MAX_CTAS")
@@ -108,9 +116,10 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
override_value = 8
os.environ["NCCL_MAX_CTAS"] = str(override_value)
backend = "deepep_low_latency" if is_deepep_ll else "deep_gemm_mega_moe"
logger.info_once(
f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
"for expert parallel with NCCL-based EPLB communicator and "
"deepep_low_latency backend",
f"for expert parallel with NCCL-based EPLB communicator and "
f"cooperative MoE backend ({backend})",
scope="global",
)
+211 -38
View File
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import typing
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, MutableSequence, Sequence
from itertools import islice
import regex as re
@@ -15,6 +15,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.kernels.mhc.tilelang import (
hc_head_fused_kernel_tilelang,
mhc_fused_post_pre_tilelang,
@@ -23,6 +24,9 @@ from vllm.model_executor.kernels.mhc.tilelang import (
)
from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.router.base_router import (
eplb_map_to_physical_and_record,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
@@ -40,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -144,6 +148,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
hidden_size: int,
intermediate_size: int,
prefix: str = "",
num_logical_experts: int | None = None,
):
super().__init__()
self.prefix = prefix
@@ -156,6 +161,12 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self.intermediate_size = intermediate_size
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.num_logical_experts = (
num_logical_experts if num_logical_experts is not None else num_experts
)
self.eplb_state = EplbLayerState()
weight_attrs = {"weight_loader": self.weight_loader}
self.w13_weight = nn.Parameter(
torch.zeros(
@@ -206,10 +217,22 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None
def _map_global_expert_id(self, expert_id: int) -> int:
if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx:
return -1
return expert_id - self.experts_start_idx
# Register in the static forward context so the custom-op wrapper
# can look up this module by name from within a torch.compile graph.
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def _map_global_expert_id(self, expert_id: int) -> list[int]:
"""Return local (per-rank) slot offsets where logical expert
`expert_id` should land on this rank.
"""
physical_ids: list[int] = []
for p in range(self.experts_start_idx, self.experts_end_idx):
if p % self.num_logical_experts == expert_id:
physical_ids.append(p - self.experts_start_idx)
return physical_ids
def weight_loader(
self,
@@ -220,30 +243,38 @@ class DeepseekV4MegaMoEExperts(nn.Module):
expert_id: int,
return_success: bool = False,
) -> bool | None:
local_expert_id = self._map_global_expert_id(expert_id)
if local_expert_id == -1:
local_expert_ids = self._map_global_expert_id(expert_id)
if not local_expert_ids:
return False if return_success else None
expert_data = param.data[local_expert_id]
if shard_id in ("w1", "w3"):
if "w13_" not in weight_name:
return False if return_success else None
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
elif shard_id == "w2":
if "w2_" not in weight_name:
return False if return_success else None
else:
raise ValueError(f"Unsupported expert shard id: {shard_id}")
loaded_any = False
for local_expert_id in local_expert_ids:
expert_data = param.data[local_expert_id]
if shard_id in ("w1", "w3"):
if "w13_" not in weight_name:
continue
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
expert_data = expert_data.narrow(
0, shard_offset, self.intermediate_size
)
elif shard_id == "w2":
if "w2_" not in weight_name:
continue
else:
raise ValueError(f"Unsupported expert shard id: {shard_id}")
if expert_data.shape != loaded_weight.shape:
raise ValueError(
f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
f"{weight_name}: parameter shard {tuple(expert_data.shape)} "
f"vs checkpoint {tuple(loaded_weight.shape)}"
)
expert_data.copy_(loaded_weight)
return True if return_success else None
if expert_data.shape != loaded_weight.shape:
raise ValueError(
f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
f"{weight_name}: parameter shard {tuple(expert_data.shape)} "
f"vs checkpoint {tuple(loaded_weight.shape)}"
)
expert_data.copy_(loaded_weight)
loaded_any = True
if return_success:
return loaded_any
return None
@staticmethod
def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
@@ -264,7 +295,9 @@ class DeepseekV4MegaMoEExperts(nn.Module):
return
self._check_runtime_supported()
import vllm.third_party.deep_gemm as deep_gemm
from vllm.utils.deep_gemm import _import_deep_gemm
deep_gemm = _import_deep_gemm()
w13_scale = deep_gemm.transform_sf_into_required_layout(
self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(),
@@ -298,7 +331,9 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self.w2_weight_scale = None
def get_symm_buffer(self):
import vllm.third_party.deep_gemm as deep_gemm
from vllm.utils.deep_gemm import _import_deep_gemm
deep_gemm = _import_deep_gemm()
group = get_ep_group().device_group
device = torch.accelerator.current_device_index()
@@ -324,6 +359,52 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self._symm_buffer_cache[key] = symm_buffer
return symm_buffer
def set_eplb_state(
self,
moe_layer_idx: int,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
self.eplb_state.set_layer_state(
moe_layer_idx,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)
def get_expert_weights(self) -> list[torch.Tensor]:
self.finalize_weights()
assert self._transformed_l1_weights is not None
assert self._transformed_l2_weights is not None
def _to_eplb_view(name: str, t: torch.Tensor) -> torch.Tensor:
"""Return a (num_local_experts, -1) view with contiguous memory layout."""
assert t.shape[0] == self.num_local_experts
if t.is_contiguous():
return t.view(self.num_local_experts, -1)
elif t.dim() == 3 and t.stride(1) == 1 and t.stride(2) == t.shape[1]:
# scales have shape (E, M, N) with memory layout (E, N, M)
back = torch.transpose(t, 1, 2)
assert back.is_contiguous()
return back.view(self.num_local_experts, -1)
raise AssertionError(
f"DSv4 EPLB {name}: non-contiguous expert tensor with "
f"unexpected layout shape={tuple(t.shape)} "
f"stride={tuple(t.stride())} dtype={t.dtype}"
)
return [
_to_eplb_view("l1_packed", self._transformed_l1_weights[0]),
_to_eplb_view("l1_scale", self._transformed_l1_weights[1]),
_to_eplb_view("l2_weight", self._transformed_l2_weights[0]),
_to_eplb_view("l2_scale", self._transformed_l2_weights[1]),
]
def update_expert_map(self) -> None:
pass
def forward(
self,
hidden_states: torch.Tensor,
@@ -358,10 +439,27 @@ class DeepseekV4MegaMoEExperts(nn.Module):
activation_clamp: float | None,
fast_math: bool,
) -> None:
import vllm.third_party.deep_gemm as deep_gemm
from vllm.utils.deep_gemm import _import_deep_gemm
deep_gemm = _import_deep_gemm()
symm_buffer = self.get_symm_buffer()
num_tokens = hidden_states.shape[0]
# EPLB: map logical expert IDs to physical replicas and record load.
eplb_state = self.eplb_state
if eplb_state.logical_to_physical_map is not None:
assert eplb_state.expert_load_view is not None
assert eplb_state.logical_replica_count is not None
assert eplb_state.should_record_tensor is not None
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
expert_load_view=eplb_state.expert_load_view,
logical_to_physical_map=eplb_state.logical_to_physical_map,
logical_replica_count=eplb_state.logical_replica_count,
record_enabled=eplb_state.should_record_tensor,
)
prepare_megamoe_inputs(
hidden_states,
topk_weights,
@@ -493,17 +591,33 @@ class DeepseekV4MoE(nn.Module):
self.ep_group = get_ep_group()
self.ep_size = self.ep_group.world_size
self.ep_rank = self.ep_group.rank_in_group
assert config.n_routed_experts % self.ep_size == 0
self.n_local_experts = config.n_routed_experts // self.ep_size
self.experts_start_idx = self.ep_rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
eplb_config = vllm_config.parallel_config.eplb_config
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_routed_experts = config.n_routed_experts
self.n_shared_experts = config.n_shared_experts or 0
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
assert self.n_physical_experts % self.ep_size == 0, (
f"n_physical_experts={self.n_physical_experts} must be divisible by "
f"ep_size={self.ep_size}. Adjust num_redundant_experts."
)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts
)
self.n_local_experts = self.n_local_physical_experts
self.experts_start_idx = self.physical_expert_start
self.experts_end_idx = self.physical_expert_end
self.experts = DeepseekV4MegaMoEExperts(
vllm_config,
num_experts=config.n_routed_experts,
num_local_experts=self.n_local_experts,
experts_start_idx=self.experts_start_idx,
num_experts=self.n_physical_experts,
num_local_experts=self.n_local_physical_experts,
experts_start_idx=self.physical_expert_start,
num_logical_experts=self.n_logical_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
@@ -1242,7 +1356,44 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
)
class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
class DeepseekV4MixtureOfExperts(MixtureOfExperts):
moe_mlp_layers: list["DeepseekV4MoE"]
def extract_moe_parameters(self, example_moe: "DeepseekV4MoE | None") -> None:
if example_moe is None:
self.num_moe_layers = 0
self.num_expert_groups = 0
self.num_logical_experts = 0
self.num_physical_experts = 0
self.num_local_physical_experts = 0
self.num_routed_experts = 0
self.num_shared_experts = 0
self.num_redundant_experts = 0
return
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for moe in self.moe_mlp_layers:
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
class DeepseekV4ForCausalLM(nn.Module, SupportsPP, DeepseekV4MixtureOfExperts):
model_cls = DeepseekV4Model
# Default mapper assumes the original FP4-expert checkpoint layout.
@@ -1274,6 +1425,28 @@ class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors
)
self.set_moe_parameters()
def set_moe_parameters(self) -> None:
self.expert_weights: MutableSequence[Sequence[torch.Tensor]] = []
self.num_expert_groups = getattr(self.config, "n_group", 1)
self.num_moe_layers = self.config.num_hidden_layers
self.moe_layers: list[nn.Module] = []
self.moe_mlp_layers: list[DeepseekV4MoE] = []
example_moe: DeepseekV4MoE | None = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
if not isinstance(layer, DeepseekV4DecoderLayer):
continue
if isinstance(layer.ffn, DeepseekV4MoE):
example_moe = layer.ffn
self.moe_mlp_layers.append(layer.ffn)
self.moe_layers.append(layer.ffn.experts)
self.num_moe_layers = len(self.moe_layers)
self.extract_moe_parameters(example_moe)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
+1
View File
@@ -140,6 +140,7 @@ _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
@functools.cache
def _import_deep_gemm():
"""Import the deep_gemm module.
+4 -1
View File
@@ -1147,7 +1147,10 @@ def init_worker_distributed_environment(
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance()
override_envs_for_eplb(parallel_config)
override_envs_for_eplb(
parallel_config,
moe_backend=getattr(vllm_config.kernel_config, "moe_backend", None),
)
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_method = distributed_init_method or "env://"