mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
4d51588e23
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai> Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Signed-off-by: qizixi <zixi@inferact.ai> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Yongye Zhu <yongye@inferact.ai> Co-authored-by: Simon Mo <simon@inferact.ai> Co-authored-by: Bugen Zhao <i@bugenzhao.com> Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roy Wang <yasong.wang@inferact.ai> Co-authored-by: Woosuk Kwon <woosuk@inferact.ai> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Zhewen Li <jerven.vllm@gmail.com> Co-authored-by: Zijing Liu <liuzijing2014@gmail.com> Co-authored-by: khluu <khluu000@gmail.com> Co-authored-by: qizixi <zixi@inferact.ai> Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
246 lines
8.1 KiB
Python
246 lines
8.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import types
|
|
from types import SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
|
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
|
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
|
|
RoutedExpertsCapturer,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
|
|
|
|
# Module-level pytest mark: applies the `cpu_test` marker to every test here.
pytestmark = pytest.mark.cpu_test
|
|
|
|
# Dotted path of the capturer module; the tests below use it as the prefix of
# the `patch` target for `get_forward_context`.
_REC_MODULE = "vllm.model_executor.layers.fused_moe.routed_experts_capturer"
|
|
|
|
|
|
def _capturer_with_buffer(
    *,
    max_tokens: int = 8,
    num_layers: int = 4,
    num_experts_per_tok: int = 2,
    dp_rank: int = 0,
) -> RoutedExpertsCapturer:
    """Return a capturer with a sentinel-filled device buffer attached.

    The buffer is an int32 tensor of shape
    ``(max_tokens, num_layers, num_experts_per_tok)`` whose every entry is
    ``-1``, so tests can tell written rows apart from untouched ones.

    Args:
        max_tokens: Size of the first buffer dimension.
        num_layers: Size of the second buffer dimension.
        num_experts_per_tok: Size of the third buffer dimension.
        dp_rank: Value stored on the capturer's ``dp_rank`` attribute.
    """
    capturer = RoutedExpertsCapturer()
    capturer.dp_rank = dp_rank

    buffer_shape = (max_tokens, num_layers, num_experts_per_tok)
    capturer._device_buffer = torch.full(buffer_shape, -1, dtype=torch.int32)
    return capturer
|
|
|
|
|
|
class DummyRouter(BaseRouter):
    """Router stub with hard-coded routing results for the tests below.

    ``_compute_routing`` ignores its inputs and always yields the ids
    ``[[1, 2], [3, 4]]`` with all-ones weights, and ``_apply_eplb_mapping``
    shifts every id by 10 so the mapped ids are distinguishable from the
    unmapped ones.
    """

    @property
    def routing_method_type(self) -> RoutingMethodType:
        """Report the fused top-k routing method."""
        return RoutingMethodType.FUSED_TOPK

    def _compute_routing(
        self, hidden_states, router_logits, indices_type, *, input_ids=None
    ):
        """Return fixed ``(weights, ids)`` regardless of the arguments."""
        fixed_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
        unit_weights = torch.ones_like(fixed_ids, dtype=torch.float32)
        return unit_weights, fixed_ids

    def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
        """Offset the ids by 10 in place of a real EPLB lookup."""
        # A plain +10 keeps the mapping visible to assertions without
        # needing the CUDA EPLB path.
        shifted_ids = topk_ids + 10
        return shifted_ids
|
|
|
|
|
|
def _make_router() -> DummyRouter:
    """Construct a ``DummyRouter`` with the fixed settings used by the tests.

    The router is created with ``top_k=2``, 16 global experts, a fresh
    ``EplbLayerState``, EPLB disabled, and no indices-type getter.
    """
    router_kwargs = {
        "top_k": 2,
        "global_num_experts": 16,
        "eplb_state": EplbLayerState(),
        "enable_eplb": False,
        "indices_type_getter": None,
    }
    return DummyRouter(**router_kwargs)
|
|
|
|
|
|
def test_base_router_capture_pre_eplb_mapping():
    """The capture hook receives the unmapped ids; the caller gets mapped ids.

    The hook must fire exactly once with the ids produced by
    ``DummyRouter._compute_routing``, while ``select_experts`` returns the
    ids after ``DummyRouter._apply_eplb_mapping`` (+10) has been applied.
    """
    seen_ids = []

    def record_ids(ids):
        # Clone so the recorded ids are a separate tensor from the one the
        # router keeps using.
        seen_ids.append(ids.clone())

    router = _make_router()
    router.set_capture_fn(record_ids)

    weights, mapped_ids = router.select_experts(
        hidden_states=torch.empty(1),
        router_logits=torch.empty(1),
    )

    expected_captured = torch.tensor([[1, 2], [3, 4]])
    expected_mapped = torch.tensor([[11, 12], [13, 14]])

    assert weights.shape == mapped_ids.shape
    assert len(seen_ids) == 1
    assert torch.equal(seen_ids[0], expected_captured)
    assert torch.equal(mapped_ids, expected_mapped)
|
|
|
|
|
|
def test_base_router_capture_with_eplb_enabled():
    """With EPLB switched on, capture still sees ids before the mapping.

    The router's EPLB state is populated with 32-entry tensors and the
    ``enable_eplb`` flag is set before routing; the expectations are the
    same as in the EPLB-disabled test.
    """
    router = _make_router()
    router.enable_eplb = True

    # Populate the EPLB state fields the router may consult when enabled.
    state = router.eplb_state
    state.expert_load_view = torch.zeros(32, dtype=torch.int64)
    state.logical_to_physical_map = torch.arange(32).view(32, 1)
    state.logical_replica_count = torch.ones(32, dtype=torch.int64)
    state.should_record_tensor = torch.ones((), dtype=torch.bool)

    seen_ids = []

    def record_ids(ids):
        seen_ids.append(ids.clone())

    router.set_capture_fn(record_ids)
    routed = router.select_experts(
        hidden_states=torch.empty(1),
        router_logits=torch.empty(1),
    )
    mapped_ids = routed[1]

    assert len(seen_ids) == 1
    # The hook must observe the logical ids, i.e. before the EPLB mapping.
    expected_logical = torch.tensor([[1, 2], [3, 4]])
    assert torch.equal(seen_ids[0], expected_logical)
    # DummyRouter's mapping adds +10 to every id.
    expected_mapped = torch.tensor([[11, 12], [13, 14]])
    assert torch.equal(mapped_ids, expected_mapped)
|
|
|
|
|
|
def test_gpu_model_runner_binds_router_capture(monkeypatch):
    """Binding installs a hook that forwards ``(layer_id, topk_ids)``.

    A stand-in ``FusedMoE`` module with ``layer_id == 7`` is placed in a fake
    static forward context. After ``_bind_routed_experts_capturer`` runs,
    invoking the router's ``capture_fn`` must reach the capturer's
    ``capture`` with that layer id and the same ids.
    """
    from vllm.v1.worker import gpu_model_runner as gmr

    class DummyFusedMoE:
        def __init__(self):
            self.layer_id = 7
            self.router = _make_router()

    class DummyCapturer:
        def __init__(self):
            self.calls = []

        def capture(self, layer_id, topk_ids):
            self.calls.append((layer_id, topk_ids))

    moe_module = DummyFusedMoE()

    # Swap in the stand-in class where _bind_routed_experts_capturer performs
    # its runtime import.
    import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer

    monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)

    forward_context = {"dummy": moe_module}
    fake_config = types.SimpleNamespace(static_forward_context=forward_context)
    fake_runner = types.SimpleNamespace(compilation_config=fake_config)

    recorder = DummyCapturer()
    gmr.GPUModelRunner._bind_routed_experts_capturer(fake_runner, recorder)

    assert moe_module.router.capture_fn is not None
    moe_module.router.capture_fn(torch.tensor([[5, 6]]))

    assert len(recorder.calls) == 1
    got_layer_id, got_ids = recorder.calls[0]
    assert got_layer_id == 7
    assert torch.equal(got_ids, torch.tensor([[5, 6]]))
|
|
|
|
|
|
def test_gpu_model_runner_binding_stage(monkeypatch):
    """The capture hook is absent before binding and callable afterwards.

    Uses a stand-in ``FusedMoE`` module (``layer_id == 11``) in a fake static
    forward context and checks the router's ``capture_fn`` on both sides of
    the ``_bind_routed_experts_capturer`` call.
    """
    from vllm.v1.worker import gpu_model_runner as gmr

    class DummyFusedMoE:
        def __init__(self):
            self.layer_id = 11
            self.router = _make_router()

    class DummyCapturer:
        def __init__(self):
            self.calls = []

        def capture(self, layer_id, topk_ids):
            self.calls.append((layer_id, topk_ids))

    moe_module = DummyFusedMoE()

    import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer

    monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)

    forward_context = {"dummy": moe_module}
    fake_config = types.SimpleNamespace(static_forward_context=forward_context)
    fake_runner = types.SimpleNamespace(compilation_config=fake_config)

    # No hook may be installed until binding has happened.
    assert moe_module.router.capture_fn is None

    recorder = DummyCapturer()
    gmr.GPUModelRunner._bind_routed_experts_capturer(fake_runner, recorder)

    # Once bound, the hook must exist, be callable, and reach the capturer.
    hook = moe_module.router.capture_fn
    assert callable(hook)
    moe_module.router.capture_fn(torch.tensor([[9, 10]]))
    assert len(recorder.calls) == 1
|
|
|
|
|
|
def test_routed_experts_capturer_single_dp_no_metadata():
    """Without DP metadata, capture stores every row of ``topk_ids``.

    Three rows are captured for layer 0; the fourth buffer row must keep
    its ``-1`` sentinel.
    """
    routed_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
    forward_ctx = SimpleNamespace(dp_metadata=None)
    capturer = _capturer_with_buffer(dp_rank=0)

    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        capturer.capture(layer_id=0, topk_ids=routed_ids)

    layer0_rows = capturer._device_buffer[:3, 0, :]
    assert torch.equal(layer0_rows, routed_ids)
    # The row after the captured ones is still untouched.
    assert capturer._device_buffer[3, 0, 0].item() == -1
|
|
|
|
|
|
def test_routed_experts_capturer_dp_naive_concatenated_all_ranks():
    """Batch dim equals the DP token total: keep only this rank's rows.

    With per-rank token counts ``[2, 3]`` and ``dp_rank == 1``, the five
    concatenated rows must be narrowed to rows 2..4 before being stored.
    """
    tokens_per_rank = torch.tensor([2, 3], dtype=torch.int32)
    dp_meta = SimpleNamespace(num_tokens_across_dp_cpu=tokens_per_rank)
    forward_ctx = SimpleNamespace(dp_metadata=dp_meta)
    capturer = _capturer_with_buffer(dp_rank=1)

    # Rows are laid out rank 0 first (two rows), then rank 1 (three rows).
    all_rank_ids = torch.tensor(
        [[0, 1], [2, 3], [10, 11], [12, 13], [14, 15]], dtype=torch.int32
    )

    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        capturer.capture(layer_id=0, topk_ids=all_rank_ids)

    rank1_rows = all_rank_ids[2:5]
    stored_rows = capturer._device_buffer[:3, 0, :]
    assert torch.equal(stored_rows, rank1_rows)
|
|
|
|
|
|
def test_routed_experts_capturer_dp_modular_local_tokens():
    """Batch dim equals this rank's token count: rows are stored as given.

    With per-rank token counts ``[2, 3]`` and ``dp_rank == 1``, a three-row
    ``topk_ids`` is already local and must land in the buffer unchanged.
    """
    tokens_per_rank = torch.tensor([2, 3], dtype=torch.int32)
    dp_meta = SimpleNamespace(num_tokens_across_dp_cpu=tokens_per_rank)
    forward_ctx = SimpleNamespace(dp_metadata=dp_meta)
    capturer = _capturer_with_buffer(dp_rank=1)

    local_ids = torch.tensor([[10, 11], [12, 13], [14, 15]], dtype=torch.int32)

    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        capturer.capture(layer_id=0, topk_ids=local_ids)

    stored_rows = capturer._device_buffer[:3, 0, :]
    assert torch.equal(stored_rows, local_ids)
|
|
|
|
|
|
def test_routed_experts_capturer_dp_unexpected_batch_raises():
    """A batch dim matching neither DP layout must raise and write nothing.

    With per-rank token counts ``[2, 3]`` and ``dp_rank == 0``, a one-row
    ``topk_ids`` is expected to trip an ``AssertionError`` and leave the
    buffer's sentinel in place.
    """
    tokens_per_rank = torch.tensor([2, 3], dtype=torch.int32)
    dp_meta = SimpleNamespace(num_tokens_across_dp_cpu=tokens_per_rank)
    forward_ctx = SimpleNamespace(dp_metadata=dp_meta)
    capturer = _capturer_with_buffer(dp_rank=0)

    # total=5, local=2: a single row matches neither the naive (5) nor the
    # modular (2) layout.
    odd_sized_ids = torch.tensor([[1, 2]], dtype=torch.int32)

    patch_target = f"{_REC_MODULE}.get_forward_context"
    with patch(patch_target, return_value=forward_ctx):
        with pytest.raises(AssertionError, match="unexpected topk_ids batch dim"):
            capturer.capture(layer_id=0, topk_ids=odd_sized_ids)

    # The failed capture must not have touched the buffer.
    assert capturer._device_buffer[0, 0, 0].item() == -1
|