[https://nvbugspro.nvidia.com/bug/5332927][fix] Fix the bug in the routing unit test (#5065)

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
ChristinaZ 2025-06-11 09:44:35 +08:00 committed by GitHub
parent 580a92521e
commit 273c6b9355
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -17,8 +17,11 @@ def test_default_moe_routing(top_k):
logits = torch.tensor(
[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]],
dtype=torch.float32)
dtype=torch.float32).cuda()
indices, scales = routing.apply(logits)
indices = indices.cpu()
scales = scales.cpu()
assert indices.shape == (3, top_k)
assert scales.shape == (3, top_k)
@@ -26,7 +29,7 @@ def test_default_moe_routing(top_k):
assert scales.dtype == torch.float32
reference_indices = torch.tensor([[3, 2, 1], [0, 1, 2], [1, 3, 2]],
dtype=torch.int32)
reference_scales = F.softmax(logits, dim=1)
reference_scales = F.softmax(logits, dim=1).cpu()
# Check that the selected experts are the largest top_k values
for i in range(top_k):
@@ -43,7 +46,6 @@ def test_default_moe_routing(top_k):
reference_scales[2, reference_indices[2, i]])
@pytest.mark.skip(reason="https://nvbugs/5332927")
@pytest.mark.parametrize("top_k", [1, 2, 3])
def test_renormalize_moe_routing(top_k):
routing = RenormalizeMoeRoutingMethod(top_k=top_k)
@@ -51,7 +53,7 @@ def test_renormalize_moe_routing(top_k):
logits = torch.tensor(
[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]],
dtype=torch.float32)
dtype=torch.float32).cuda()
indices, scales = routing.apply(logits)
assert indices.shape == (3, top_k)
assert scales.shape == (3, top_k)
@@ -78,7 +80,7 @@ def gen_unique_logits(num_tokens, num_experts, dtype):
return unique_logits.cuda()
@pytest.mark.parametrize("num_tokens", [1, 30, 2000])
@pytest.mark.parametrize("num_tokens", [30])
@pytest.mark.parametrize("top_k", [2, 8])
@pytest.mark.parametrize("dtype",
[torch.bfloat16, torch.float32, torch.float16])
@@ -110,8 +112,10 @@ def test_sparse_mixer_reference():
[2.0, 0.0, -float('inf'), -float('inf')],
[0.0, 2.0, -float('inf'), -float('inf')],
[1.0, 1.0, 1.0, -float('inf')]],
dtype=torch.float32)
dtype=torch.float32).cuda()
indices, scales = routing.apply(logits.clone())
indices = indices.cpu()
scales = scales.cpu()
assert indices.shape == (4, routing.experts_per_token)
assert scales.shape == (4, routing.experts_per_token)
@@ -147,7 +151,7 @@ def test_load_balanced_moe_routing():
assert routing.experts_per_token == k
# Values don't matter for load balanced routing
logits = torch.empty((tokens, 4), dtype=torch.float32)
logits = torch.empty((tokens, 4), dtype=torch.float32).cuda()
indices, scales = routing.apply(logits)
assert indices.shape == (tokens, k)
@@ -164,12 +168,14 @@ def test_load_balanced_moe_routing():
def test_static_moe_routing():
routing = StaticMoeRoutingMethod(
torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))
torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda())
assert routing.experts_per_token == 4
logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
dtype=torch.float32)
dtype=torch.float32).cuda()
indices, scales = routing.apply(logits)
indices = indices.cpu()
assert scales is None
assert indices.shape == (2, 4)
assert indices.dtype == torch.int32
@@ -178,10 +184,12 @@ def test_static_moe_routing():
indices, torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))
routing = StaticMoeRoutingMethod(
torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32),
torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda(),
torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
dtype=torch.float32))
dtype=torch.float32).cuda())
indices, scales = routing.apply(logits)
scales = scales.cpu()
assert scales is not None
assert scales.shape == (2, 4)
assert scales.dtype == torch.float32