diff --git a/tests/distributed/test_mnnvl_alltoall.py b/tests/distributed/test_mnnvl_alltoall.py index f395c96a3d3..875b65ff084 100644 --- a/tests/distributed/test_mnnvl_alltoall.py +++ b/tests/distributed/test_mnnvl_alltoall.py @@ -342,6 +342,90 @@ def test_one_sided_manager_lifecycle(world_size): ) +# --------------------------------------------------------------------------- +# Test 2b: One-sided manager grows workspace across heterogeneous MoE layers +# --------------------------------------------------------------------------- +# +# Models with heterogeneous MoE quantization — most notably a quantized base +# MoE combined with an unquantized MTP head — can call initialize() multiple +# times with different per-token dispatch payload sizes. The shared workspace +# must grow to the union and the MoeAlltoAll must be rebuilt; otherwise a +# later layer's combine call overruns the workspace sized for the first +# layer's smaller payload and trips FlashInfer's combinePayloadOffset assert. +# --------------------------------------------------------------------------- + + +def _one_sided_workspace_grow_worker(rank, world_size): + from vllm.distributed.device_communicators.all2all import ( + FlashInferNVLinkOneSidedManager, + ) + from vllm.distributed.parallel_state import get_dp_group + + cpu_group = get_dp_group().cpu_group + manager = FlashInferNVLinkOneSidedManager(cpu_group) + + base_kwargs = dict( + max_num_tokens=1024, + top_k=2, + num_experts=world_size * 8, + hidden_size=4096, + ) + nvfp4_kwargs = dict( + dispatch_dtype_bytes_per_elem=0, + dispatch_scale_bytes_per_token=base_kwargs["hidden_size"] // 16, + ) + bf16_kwargs = dict( + dispatch_dtype_bytes_per_elem=2, + dispatch_scale_bytes_per_token=0, + ) + + # First init: NVFP4-like (hidden_bytes = hidden // 2 + hidden // 16). + manager.initialize(**base_kwargs, **nvfp4_kwargs) + assert manager.initialized + nvfp4_workspace_size = manager.workspace_size + nvfp4_moe_alltoall = manager.moe_alltoall + + torch.distributed.barrier() + + # Second init: bf16-like (hidden_bytes = hidden * 2). Models the case of + # a quantized base MoE followed by an unquantized MoE layer (e.g. an MTP + # head). Per-token dispatch payload is ~4x larger, so the union workspace + # must grow and MoeAlltoAll must be rebuilt. + manager.initialize(**base_kwargs, **bf16_kwargs) + assert manager.initialized + assert manager.workspace_size > nvfp4_workspace_size + assert manager.moe_alltoall is not nvfp4_moe_alltoall + bf16_workspace_size = manager.workspace_size + bf16_moe_alltoall = manager.moe_alltoall + + torch.distributed.barrier() + + # Third init: back to NVFP4-like shape. Existing workspace already covers + # it, so initialize() must no-op — no shrink, no rebuild. + manager.initialize(**base_kwargs, **nvfp4_kwargs) + assert manager.initialized + assert manager.workspace_size == bf16_workspace_size + assert manager.moe_alltoall is bf16_moe_alltoall + + torch.distributed.barrier() + manager.cleanup() + + +@requires_multi_gpu +@requires_one_sided +@requires_ptrace +@pytest.mark.parametrize("world_size", [2]) +def test_one_sided_manager_workspace_grow(world_size): + """A later initialize() with a larger per-token payload must grow the + workspace and rebuild MoeAlltoAll; a later initialize() with a smaller + payload must no-op.""" + _spawn_workers( + _one_sided_workspace_grow_worker, + world_size, + dp_size=world_size, + ) + + # --------------------------------------------------------------------------- # Test 3: AgRs dispatch/combine with value validation # --------------------------------------------------------------------------- diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7fd02bce615..8503a0a59e9 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -571,6 +571,10 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase): self.initialized = False self.moe_alltoall: MoeAlltoAll | None = None self.mapping = None + self.workspace_size = 0 + self.max_num_tokens = 0 + self.top_k = 0 + self.num_experts = 0 def initialize( self, @@ -581,9 +585,54 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase): dispatch_dtype_bytes_per_elem: int = 0, dispatch_scale_bytes_per_token: int = 0, ): - """Initialize the MoeAlltoAll workspace.""" + """Initialize (or grow) the MoeAlltoAll workspace.""" + if dispatch_dtype_bytes_per_elem == 0: + hidden_bytes = hidden_size // 2 + else: + hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem + total_dispatch_payload_size_per_token = ( + hidden_bytes + + dispatch_scale_bytes_per_token + + top_k * 4 # int32 topks ids + + top_k * 4 # float32 topk weights + ) + combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states + needed_workspace_size = moe_a2a_get_workspace_size_per_rank( + ep_size=self.world_size, + max_num_tokens=max_num_tokens, + total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token, + combine_payload_size_per_token=combine_payload_size_per_token, + ) + # workspace_size and max_num_tokens are kernel-side max-bounds, so + # heterogeneous MoE layers (e.g. NVFP4 base + bf16 MTP head) only + # need the shared workspace grown to the union. top_k and num_experts + # must match across layers: top_k is a strict-equality assert at + # dispatch (FlashInfer csrc/trtllm_moe_alltoall.cu), and num_experts + # feeds the expert-to-rank routing math, so any mismatch would crash + # or silently corrupt routing. All ranks see the same MoE layers in + # the same order with identical shapes, so the skip / rebuild + # branches are taken consistently across ranks. if self.initialized: - return + assert top_k == self.top_k, ( + "FlashInfer one-sided MoeAlltoAll does not support " + f"heterogeneous top_k across MoE layers (got {top_k}, " + f"was built with {self.top_k})" + ) + assert num_experts == self.num_experts, ( + "FlashInfer one-sided MoeAlltoAll does not support " + f"heterogeneous num_experts across MoE layers (got " + f"{num_experts}, was built with {self.num_experts})" + ) + if ( + needed_workspace_size <= self.workspace_size + and max_num_tokens <= self.max_num_tokens + ): + return + + self.workspace_size = max(self.workspace_size, needed_workspace_size) + self.max_num_tokens = max(self.max_num_tokens, max_num_tokens) + self.top_k = top_k + self.num_experts = num_experts self.cleanup() gpus_per_node = torch.accelerator.device_count() @@ -610,38 +659,17 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase): ep_config = MnnvlConfig( comm_backend=CustomCommunicator(self.cpu_group), ) - if dispatch_dtype_bytes_per_elem == 0: - hidden_bytes = hidden_size // 2 - else: - hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem - total_dispatch_payload_size_per_token = ( - hidden_bytes - + dispatch_scale_bytes_per_token - + top_k * 4 # int32 topks ids - + top_k * 4 # float32 topk weights - ) - combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states - self.workspace_size = moe_a2a_get_workspace_size_per_rank( - ep_size=self.world_size, - max_num_tokens=max_num_tokens, - total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token, - combine_payload_size_per_token=combine_payload_size_per_token, - ) self.moe_alltoall = MoeAlltoAll( mapping=self.mapping, - max_num_tokens=max_num_tokens, - top_k=top_k, - num_experts=num_experts, + max_num_tokens=self.max_num_tokens, + top_k=self.top_k, + num_experts=self.num_experts, workspace_size_per_rank=self.workspace_size, mnnvl_config=ep_config, ) self.gpus_per_node = gpus_per_node - self.max_num_tokens = max_num_tokens - self.top_k = top_k - self.num_experts = num_experts - self.hidden_size = hidden_size self.initialized = True logger.info( @@ -649,7 +677,10 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase): self.rank, self.world_size, ) - dist.barrier() + # Scope barrier to the EP group: with PP, different EP groups can + # rebuild a different number of times if their MoE layers have + # different shape sequences, so a world-level barrier would deadlock. + dist.barrier(group=self.cpu_group) def get_handle(self, kwargs): return self