mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feature] Support EPLB for DeepSeek v4 Mega Moe (#43339)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Wei Zhao (Engrg-Hardware 1) <weizha@login-lyris01.lyris.clusters.nvidia.com>
This commit is contained in:
@@ -61,25 +61,31 @@ class CpuGpuEvent:
|
||||
self._recorded.set()
|
||||
|
||||
|
||||
def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
def override_envs_for_eplb(
|
||||
parallel_config: ParallelConfig,
|
||||
moe_backend: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Override environment variables for EPLB when specific conditions are met.
|
||||
|
||||
Args:
|
||||
parallel_config: The parallel configuration object.
|
||||
moe_backend: The configured MoE backend (e.g. ``deep_gemm_mega_moe``).
|
||||
"""
|
||||
is_data_parallel = parallel_config.data_parallel_size > 1
|
||||
is_eplb_enabled = parallel_config.enable_eplb
|
||||
async_eplb = parallel_config.eplb_config.use_async
|
||||
is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency"
|
||||
is_mega_moe = moe_backend == "deep_gemm_mega_moe"
|
||||
is_nccl_based_eplb_communicator = parallel_config.eplb_config.communicator in (
|
||||
"torch_nccl",
|
||||
"pynccl",
|
||||
)
|
||||
|
||||
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
|
||||
# DeepEP low-latency backend.
|
||||
# Override NCCL_MAX_CTAS to avoid hangs when EPLB's NCCL weight exchange
|
||||
# contends with MoE backend's cooperative-launch on GPU SMs.
|
||||
#
|
||||
# DeepEP low-latency:
|
||||
# The hang happens when two ranks interleave kernel launches differently
|
||||
# between NCCL collectives (used by async EPLB weight exchange) and DeepEP
|
||||
# low-latency (LL) kernels. DeepEP LL uses a cooperative launch and tries
|
||||
@@ -94,12 +100,14 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
|
||||
# cooperative kernel to launch and complete, breaking the deadlock.
|
||||
# See: https://github.com/deepseek-ai/DeepEP/issues/496
|
||||
#
|
||||
# DeepGEMM Mega MoE also uses cooperative launch and will cause hang even
|
||||
# with sync EPLB.
|
||||
if (
|
||||
is_data_parallel
|
||||
and is_eplb_enabled
|
||||
and is_deepep_ll
|
||||
and async_eplb
|
||||
and is_nccl_based_eplb_communicator
|
||||
and ((is_deepep_ll and async_eplb) or is_mega_moe)
|
||||
):
|
||||
current_value_str = os.getenv("NCCL_MAX_CTAS")
|
||||
|
||||
@@ -108,9 +116,10 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
|
||||
override_value = 8
|
||||
os.environ["NCCL_MAX_CTAS"] = str(override_value)
|
||||
backend = "deepep_low_latency" if is_deepep_ll else "deep_gemm_mega_moe"
|
||||
logger.info_once(
|
||||
f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
|
||||
"for expert parallel with NCCL-based EPLB communicator and "
|
||||
"deepep_low_latency backend",
|
||||
f"for expert parallel with NCCL-based EPLB communicator and "
|
||||
f"cooperative MoE backend ({backend})",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Callable, Iterable, MutableSequence, Sequence
|
||||
from itertools import islice
|
||||
|
||||
import regex as re
|
||||
@@ -15,6 +15,7 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.kernels.mhc.tilelang import (
|
||||
hc_head_fused_kernel_tilelang,
|
||||
mhc_fused_post_pre_tilelang,
|
||||
@@ -23,6 +24,9 @@ from vllm.model_executor.kernels.mhc.tilelang import (
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.router.base_router import (
|
||||
eplb_map_to_physical_and_record,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
||||
fused_topk_bias,
|
||||
)
|
||||
@@ -40,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsPP
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -144,6 +148,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
prefix: str = "",
|
||||
num_logical_experts: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
@@ -156,6 +161,12 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self.intermediate_size = intermediate_size
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
|
||||
self.num_logical_experts = (
|
||||
num_logical_experts if num_logical_experts is not None else num_experts
|
||||
)
|
||||
|
||||
self.eplb_state = EplbLayerState()
|
||||
|
||||
weight_attrs = {"weight_loader": self.weight_loader}
|
||||
self.w13_weight = nn.Parameter(
|
||||
torch.zeros(
|
||||
@@ -206,10 +217,22 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
|
||||
self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None
|
||||
|
||||
def _map_global_expert_id(self, expert_id: int) -> int:
|
||||
if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx:
|
||||
return -1
|
||||
return expert_id - self.experts_start_idx
|
||||
# Register in the static forward context so the custom-op wrapper
|
||||
# can look up this module by name from within a torch.compile graph.
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def _map_global_expert_id(self, expert_id: int) -> list[int]:
|
||||
"""Return local (per-rank) slot offsets where logical expert
|
||||
`expert_id` should land on this rank.
|
||||
"""
|
||||
physical_ids: list[int] = []
|
||||
for p in range(self.experts_start_idx, self.experts_end_idx):
|
||||
if p % self.num_logical_experts == expert_id:
|
||||
physical_ids.append(p - self.experts_start_idx)
|
||||
return physical_ids
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
@@ -220,30 +243,38 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
expert_id: int,
|
||||
return_success: bool = False,
|
||||
) -> bool | None:
|
||||
local_expert_id = self._map_global_expert_id(expert_id)
|
||||
if local_expert_id == -1:
|
||||
local_expert_ids = self._map_global_expert_id(expert_id)
|
||||
if not local_expert_ids:
|
||||
return False if return_success else None
|
||||
|
||||
expert_data = param.data[local_expert_id]
|
||||
if shard_id in ("w1", "w3"):
|
||||
if "w13_" not in weight_name:
|
||||
return False if return_success else None
|
||||
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
|
||||
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
|
||||
elif shard_id == "w2":
|
||||
if "w2_" not in weight_name:
|
||||
return False if return_success else None
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert shard id: {shard_id}")
|
||||
loaded_any = False
|
||||
for local_expert_id in local_expert_ids:
|
||||
expert_data = param.data[local_expert_id]
|
||||
if shard_id in ("w1", "w3"):
|
||||
if "w13_" not in weight_name:
|
||||
continue
|
||||
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
|
||||
expert_data = expert_data.narrow(
|
||||
0, shard_offset, self.intermediate_size
|
||||
)
|
||||
elif shard_id == "w2":
|
||||
if "w2_" not in weight_name:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert shard id: {shard_id}")
|
||||
|
||||
if expert_data.shape != loaded_weight.shape:
|
||||
raise ValueError(
|
||||
f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
|
||||
f"{weight_name}: parameter shard {tuple(expert_data.shape)} "
|
||||
f"vs checkpoint {tuple(loaded_weight.shape)}"
|
||||
)
|
||||
expert_data.copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
if expert_data.shape != loaded_weight.shape:
|
||||
raise ValueError(
|
||||
f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
|
||||
f"{weight_name}: parameter shard {tuple(expert_data.shape)} "
|
||||
f"vs checkpoint {tuple(loaded_weight.shape)}"
|
||||
)
|
||||
expert_data.copy_(loaded_weight)
|
||||
loaded_any = True
|
||||
|
||||
if return_success:
|
||||
return loaded_any
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
|
||||
@@ -264,7 +295,9 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
return
|
||||
|
||||
self._check_runtime_supported()
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
from vllm.utils.deep_gemm import _import_deep_gemm
|
||||
|
||||
deep_gemm = _import_deep_gemm()
|
||||
|
||||
w13_scale = deep_gemm.transform_sf_into_required_layout(
|
||||
self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(),
|
||||
@@ -298,7 +331,9 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self.w2_weight_scale = None
|
||||
|
||||
def get_symm_buffer(self):
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
from vllm.utils.deep_gemm import _import_deep_gemm
|
||||
|
||||
deep_gemm = _import_deep_gemm()
|
||||
|
||||
group = get_ep_group().device_group
|
||||
device = torch.accelerator.current_device_index()
|
||||
@@ -324,6 +359,52 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self._symm_buffer_cache[key] = symm_buffer
|
||||
return symm_buffer
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
moe_layer_idx: int,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
self.eplb_state.set_layer_state(
|
||||
moe_layer_idx,
|
||||
expert_load_view,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
def get_expert_weights(self) -> list[torch.Tensor]:
|
||||
self.finalize_weights()
|
||||
assert self._transformed_l1_weights is not None
|
||||
assert self._transformed_l2_weights is not None
|
||||
|
||||
def _to_eplb_view(name: str, t: torch.Tensor) -> torch.Tensor:
|
||||
"""Return a (num_local_experts, -1) view with contiguous memory layout."""
|
||||
assert t.shape[0] == self.num_local_experts
|
||||
if t.is_contiguous():
|
||||
return t.view(self.num_local_experts, -1)
|
||||
elif t.dim() == 3 and t.stride(1) == 1 and t.stride(2) == t.shape[1]:
|
||||
# scales have shape (E, M, N) with memory layout (E, N, M)
|
||||
back = torch.transpose(t, 1, 2)
|
||||
assert back.is_contiguous()
|
||||
return back.view(self.num_local_experts, -1)
|
||||
|
||||
raise AssertionError(
|
||||
f"DSv4 EPLB {name}: non-contiguous expert tensor with "
|
||||
f"unexpected layout shape={tuple(t.shape)} "
|
||||
f"stride={tuple(t.stride())} dtype={t.dtype}"
|
||||
)
|
||||
|
||||
return [
|
||||
_to_eplb_view("l1_packed", self._transformed_l1_weights[0]),
|
||||
_to_eplb_view("l1_scale", self._transformed_l1_weights[1]),
|
||||
_to_eplb_view("l2_weight", self._transformed_l2_weights[0]),
|
||||
_to_eplb_view("l2_scale", self._transformed_l2_weights[1]),
|
||||
]
|
||||
|
||||
def update_expert_map(self) -> None:
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -358,10 +439,27 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
activation_clamp: float | None,
|
||||
fast_math: bool,
|
||||
) -> None:
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
from vllm.utils.deep_gemm import _import_deep_gemm
|
||||
|
||||
deep_gemm = _import_deep_gemm()
|
||||
|
||||
symm_buffer = self.get_symm_buffer()
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# EPLB: map logical expert IDs to physical replicas and record load.
|
||||
eplb_state = self.eplb_state
|
||||
if eplb_state.logical_to_physical_map is not None:
|
||||
assert eplb_state.expert_load_view is not None
|
||||
assert eplb_state.logical_replica_count is not None
|
||||
assert eplb_state.should_record_tensor is not None
|
||||
topk_ids = eplb_map_to_physical_and_record(
|
||||
topk_ids=topk_ids,
|
||||
expert_load_view=eplb_state.expert_load_view,
|
||||
logical_to_physical_map=eplb_state.logical_to_physical_map,
|
||||
logical_replica_count=eplb_state.logical_replica_count,
|
||||
record_enabled=eplb_state.should_record_tensor,
|
||||
)
|
||||
|
||||
prepare_megamoe_inputs(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
@@ -493,17 +591,33 @@ class DeepseekV4MoE(nn.Module):
|
||||
self.ep_group = get_ep_group()
|
||||
self.ep_size = self.ep_group.world_size
|
||||
self.ep_rank = self.ep_group.rank_in_group
|
||||
assert config.n_routed_experts % self.ep_size == 0
|
||||
|
||||
self.n_local_experts = config.n_routed_experts // self.ep_size
|
||||
self.experts_start_idx = self.ep_rank * self.n_local_experts
|
||||
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.n_shared_experts = config.n_shared_experts or 0
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
|
||||
assert self.n_physical_experts % self.ep_size == 0, (
|
||||
f"n_physical_experts={self.n_physical_experts} must be divisible by "
|
||||
f"ep_size={self.ep_size}. Adjust num_redundant_experts."
|
||||
)
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
|
||||
self.physical_expert_end = (
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
self.n_local_experts = self.n_local_physical_experts
|
||||
self.experts_start_idx = self.physical_expert_start
|
||||
self.experts_end_idx = self.physical_expert_end
|
||||
|
||||
self.experts = DeepseekV4MegaMoEExperts(
|
||||
vllm_config,
|
||||
num_experts=config.n_routed_experts,
|
||||
num_local_experts=self.n_local_experts,
|
||||
experts_start_idx=self.experts_start_idx,
|
||||
num_experts=self.n_physical_experts,
|
||||
num_local_experts=self.n_local_physical_experts,
|
||||
experts_start_idx=self.physical_expert_start,
|
||||
num_logical_experts=self.n_logical_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
@@ -1242,7 +1356,44 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
|
||||
class DeepseekV4MixtureOfExperts(MixtureOfExperts):
|
||||
moe_mlp_layers: list["DeepseekV4MoE"]
|
||||
|
||||
def extract_moe_parameters(self, example_moe: "DeepseekV4MoE | None") -> None:
|
||||
if example_moe is None:
|
||||
self.num_moe_layers = 0
|
||||
self.num_expert_groups = 0
|
||||
self.num_logical_experts = 0
|
||||
self.num_physical_experts = 0
|
||||
self.num_local_physical_experts = 0
|
||||
self.num_routed_experts = 0
|
||||
self.num_shared_experts = 0
|
||||
self.num_redundant_experts = 0
|
||||
return
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for moe in self.moe_mlp_layers:
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
|
||||
class DeepseekV4ForCausalLM(nn.Module, SupportsPP, DeepseekV4MixtureOfExperts):
|
||||
model_cls = DeepseekV4Model
|
||||
|
||||
# Default mapper assumes the original FP4-expert checkpoint layout.
|
||||
@@ -1274,6 +1425,28 @@ class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
self.set_moe_parameters()
|
||||
|
||||
def set_moe_parameters(self) -> None:
|
||||
self.expert_weights: MutableSequence[Sequence[torch.Tensor]] = []
|
||||
self.num_expert_groups = getattr(self.config, "n_group", 1)
|
||||
self.num_moe_layers = self.config.num_hidden_layers
|
||||
self.moe_layers: list[nn.Module] = []
|
||||
self.moe_mlp_layers: list[DeepseekV4MoE] = []
|
||||
example_moe: DeepseekV4MoE | None = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
if not isinstance(layer, DeepseekV4DecoderLayer):
|
||||
continue
|
||||
if isinstance(layer.ffn, DeepseekV4MoE):
|
||||
example_moe = layer.ffn
|
||||
self.moe_mlp_layers.append(layer.ffn)
|
||||
self.moe_layers.append(layer.ffn.experts)
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
|
||||
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _import_deep_gemm():
|
||||
"""Import the deep_gemm module.
|
||||
|
||||
|
||||
@@ -1147,7 +1147,10 @@ def init_worker_distributed_environment(
|
||||
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
|
||||
|
||||
init_batch_invariance()
|
||||
override_envs_for_eplb(parallel_config)
|
||||
override_envs_for_eplb(
|
||||
parallel_config,
|
||||
moe_backend=getattr(vllm_config.kernel_config, "moe_backend", None),
|
||||
)
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_method = distributed_init_method or "env://"
|
||||
|
||||
Reference in New Issue
Block a user