[Bugfix][MoE] FlashInfer one-sided: workspace union across heterogeneous layers (#42976)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91
2026-05-19 21:43:04 +03:00
committed by GitHub
parent aed2eb355a
commit f54721bcc3
2 changed files with 142 additions and 27 deletions
+84
View File
@@ -342,6 +342,90 @@ def test_one_sided_manager_lifecycle(world_size):
)
# ---------------------------------------------------------------------------
# Test 2b: One-sided manager grows workspace across heterogeneous MoE layers
# ---------------------------------------------------------------------------
#
# Models with heterogeneous MoE quantization — most notably a quantized base
# MoE combined with an unquantized MTP head — can call initialize() multiple
# times with different per-token dispatch payload sizes. The shared workspace
# must grow to the union and the MoeAlltoAll must be rebuilt; otherwise a
# later layer's combine call overruns the workspace sized for the first
# layer's smaller payload and trips FlashInfer's combinePayloadOffset assert.
# ---------------------------------------------------------------------------
def _one_sided_workspace_grow_worker(rank, world_size):
from vllm.distributed.device_communicators.all2all import (
FlashInferNVLinkOneSidedManager,
)
from vllm.distributed.parallel_state import get_dp_group
cpu_group = get_dp_group().cpu_group
manager = FlashInferNVLinkOneSidedManager(cpu_group)
base_kwargs = dict(
max_num_tokens=1024,
top_k=2,
num_experts=world_size * 8,
hidden_size=4096,
)
nvfp4_kwargs = dict(
dispatch_dtype_bytes_per_elem=0,
dispatch_scale_bytes_per_token=base_kwargs["hidden_size"] // 16,
)
bf16_kwargs = dict(
dispatch_dtype_bytes_per_elem=2,
dispatch_scale_bytes_per_token=0,
)
# First init: NVFP4-like (hidden_bytes = hidden // 2 + hidden // 16).
manager.initialize(**base_kwargs, **nvfp4_kwargs)
assert manager.initialized
nvfp4_workspace_size = manager.workspace_size
nvfp4_moe_alltoall = manager.moe_alltoall
torch.distributed.barrier()
# Second init: bf16-like (hidden_bytes = hidden * 2). Models the case of
# a quantized base MoE followed by an unquantized MoE layer (e.g. an MTP
# head). Per-token dispatch payload is ~4x larger, so the union workspace
# must grow and MoeAlltoAll must be rebuilt.
manager.initialize(**base_kwargs, **bf16_kwargs)
assert manager.initialized
assert manager.workspace_size > nvfp4_workspace_size
assert manager.moe_alltoall is not nvfp4_moe_alltoall
bf16_workspace_size = manager.workspace_size
bf16_moe_alltoall = manager.moe_alltoall
torch.distributed.barrier()
# Third init: back to NVFP4-like shape. Existing workspace already covers
# it, so initialize() must no-op — no shrink, no rebuild.
manager.initialize(**base_kwargs, **nvfp4_kwargs)
assert manager.initialized
assert manager.workspace_size == bf16_workspace_size
assert manager.moe_alltoall is bf16_moe_alltoall
torch.distributed.barrier()
manager.cleanup()
@requires_multi_gpu
@requires_one_sided
@requires_ptrace
@pytest.mark.parametrize("world_size", [2])
def test_one_sided_manager_workspace_grow(world_size):
"""A later initialize() with a larger per-token payload must grow the
workspace and rebuild MoeAlltoAll; a later initialize() with a smaller
payload must no-op."""
_spawn_workers(
_one_sided_workspace_grow_worker,
world_size,
dp_size=world_size,
)
# ---------------------------------------------------------------------------
# Test 3: AgRs dispatch/combine with value validation
# ---------------------------------------------------------------------------
@@ -571,6 +571,10 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
self.initialized = False
self.moe_alltoall: MoeAlltoAll | None = None
self.mapping = None
self.workspace_size = 0
self.max_num_tokens = 0
self.top_k = 0
self.num_experts = 0
def initialize(
self,
@@ -581,9 +585,54 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
dispatch_dtype_bytes_per_elem: int = 0,
dispatch_scale_bytes_per_token: int = 0,
):
"""Initialize the MoeAlltoAll workspace."""
"""Initialize (or grow) the MoeAlltoAll workspace."""
if dispatch_dtype_bytes_per_elem == 0:
hidden_bytes = hidden_size // 2
else:
hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem
total_dispatch_payload_size_per_token = (
hidden_bytes
+ dispatch_scale_bytes_per_token
+ top_k * 4 # int32 topks ids
+ top_k * 4 # float32 topk weights
)
combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states
needed_workspace_size = moe_a2a_get_workspace_size_per_rank(
ep_size=self.world_size,
max_num_tokens=max_num_tokens,
total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
combine_payload_size_per_token=combine_payload_size_per_token,
)
# workspace_size and max_num_tokens are kernel-side max-bounds, so
# heterogeneous MoE layers (e.g. NVFP4 base + bf16 MTP head) only
# need the shared workspace grown to the union. top_k and num_experts
# must match across layers: top_k is a strict-equality assert at
# dispatch (FlashInfer csrc/trtllm_moe_alltoall.cu), and num_experts
# feeds the expert-to-rank routing math, so any mismatch would crash
# or silently corrupt routing. All ranks see the same MoE layers in
# the same order with identical shapes, so the skip / rebuild
# branches are taken consistently across ranks.
if self.initialized:
return
assert top_k == self.top_k, (
"FlashInfer one-sided MoeAlltoAll does not support "
f"heterogeneous top_k across MoE layers (got {top_k}, "
f"was built with {self.top_k})"
)
assert num_experts == self.num_experts, (
"FlashInfer one-sided MoeAlltoAll does not support "
f"heterogeneous num_experts across MoE layers (got "
f"{num_experts}, was built with {self.num_experts})"
)
if (
needed_workspace_size <= self.workspace_size
and max_num_tokens <= self.max_num_tokens
):
return
self.workspace_size = max(self.workspace_size, needed_workspace_size)
self.max_num_tokens = max(self.max_num_tokens, max_num_tokens)
self.top_k = top_k
self.num_experts = num_experts
self.cleanup()
gpus_per_node = torch.accelerator.device_count()
@@ -610,38 +659,17 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
ep_config = MnnvlConfig(
comm_backend=CustomCommunicator(self.cpu_group),
)
if dispatch_dtype_bytes_per_elem == 0:
hidden_bytes = hidden_size // 2
else:
hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem
total_dispatch_payload_size_per_token = (
hidden_bytes
+ dispatch_scale_bytes_per_token
+ top_k * 4 # int32 topks ids
+ top_k * 4 # float32 topk weights
)
combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states
self.workspace_size = moe_a2a_get_workspace_size_per_rank(
ep_size=self.world_size,
max_num_tokens=max_num_tokens,
total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
combine_payload_size_per_token=combine_payload_size_per_token,
)
self.moe_alltoall = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=max_num_tokens,
top_k=top_k,
num_experts=num_experts,
max_num_tokens=self.max_num_tokens,
top_k=self.top_k,
num_experts=self.num_experts,
workspace_size_per_rank=self.workspace_size,
mnnvl_config=ep_config,
)
self.gpus_per_node = gpus_per_node
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.num_experts = num_experts
self.hidden_size = hidden_size
self.initialized = True
logger.info(
@@ -649,7 +677,10 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
self.rank,
self.world_size,
)
dist.barrier()
# Scope barrier to the EP group: with PP, different EP groups can
# rebuild a different number of times if their MoE layers have
# different shape sequences, so a world-level barrier would deadlock.
dist.barrier(group=self.cpu_group)
def get_handle(self, kwargs):
return self