Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-09 12:41:52 +08:00
[https://nvbugspro.nvidia.com/bug/5332927][fix] Fix the bug in the routing unit test (#5065)
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
parent: 580a92521e
commit: 273c6b9355
@@ -17,8 +17,11 @@ def test_default_moe_routing(top_k):
     logits = torch.tensor(
         [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]],
-        dtype=torch.float32)
+        dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits)
+    indices = indices.cpu()
+    scales = scales.cpu()
 
     assert indices.shape == (3, top_k)
     assert scales.shape == (3, top_k)
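Note: the pattern this commit applies throughout is visible in the hunk above. Inputs are built on the host, moved to the GPU with .cuda() before the routing kernel runs, and outputs are copied back with .cpu() before comparison against host-side references. A minimal sketch of that round-trip, with a hypothetical gpu_only_topk standing in for the real routing classes:

import torch

# Hypothetical stand-in for a routing kernel that only accepts CUDA tensors;
# the real routing methods live in tensorrt_llm and may differ in detail.
def gpu_only_topk(logits: torch.Tensor, top_k: int):
    assert logits.is_cuda, "routing kernel expects CUDA tensors"
    scales, indices = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
    return indices.to(torch.int32), scales

if torch.cuda.is_available():
    logits = torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32).cuda()
    indices, scales = gpu_only_topk(logits, top_k=2)
    # Copy results back to the host before comparing with CPU references;
    # torch.testing.assert_close fails on mismatched devices otherwise.
    reference = torch.softmax(logits, dim=-1).cpu()
    torch.testing.assert_close(scales.cpu(),
                               reference.gather(1, indices.cpu().long()))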
@@ -26,7 +29,7 @@ def test_default_moe_routing(top_k):
     assert scales.dtype == torch.float32
     reference_indices = torch.tensor([[3, 2, 1], [0, 1, 2], [1, 3, 2]],
                                      dtype=torch.int32)
-    reference_scales = F.softmax(logits, dim=1)
+    reference_scales = F.softmax(logits, dim=1).cpu()
 
     # Check that the selected experts are the largest top_k values
     for i in range(top_k):
@@ -43,7 +46,6 @@ def test_default_moe_routing(top_k):
                 reference_scales[2, reference_indices[2, i]])
 
 
-@pytest.mark.skip(reason="https://nvbugs/5332927")
 @pytest.mark.parametrize("top_k", [1, 2, 3])
 def test_renormalize_moe_routing(top_k):
     routing = RenormalizeMoeRoutingMethod(top_k=top_k)
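This hunk drops the @pytest.mark.skip marker that pointed at the bug this commit fixes, so test_renormalize_moe_routing is enabled again. Unrelated to this diff, but worth noting since the rewritten tests now allocate CUDA tensors: a skipif guard is the usual way to keep such tests from erroring out on machines without a GPU. A sketch:

import pytest
import torch

# Illustration only, not part of this commit: skip CUDA-only tests cleanly
# when no GPU is present instead of failing inside torch.Tensor.cuda().
requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(),
                                   reason="test requires a CUDA device")

@requires_cuda
@pytest.mark.parametrize("top_k", [1, 2, 3])
def test_renormalize_runs_on_gpu(top_k):
    logits = torch.rand(3, 4).cuda()
    assert logits.is_cuda and top_k in (1, 2, 3)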
@@ -51,7 +53,7 @@ def test_renormalize_moe_routing(top_k):
 
     logits = torch.tensor(
         [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]],
-        dtype=torch.float32)
+        dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits)
     assert indices.shape == (3, top_k)
     assert scales.shape == (3, top_k)
@@ -78,7 +80,7 @@ def gen_unique_logits(num_tokens, num_experts, dtype):
     return unique_logits.cuda()
 
 
-@pytest.mark.parametrize("num_tokens", [1, 30, 2000])
+@pytest.mark.parametrize("num_tokens", [30])
 @pytest.mark.parametrize("top_k", [2, 8])
 @pytest.mark.parametrize("dtype",
                          [torch.bfloat16, torch.float32, torch.float16])
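Stacked @pytest.mark.parametrize decorators generate the cross-product of their values, so narrowing num_tokens from [1, 30, 2000] to [30] cuts this test from 3 × 2 × 3 = 18 generated cases to 1 × 2 × 3 = 6. A self-contained illustration of the multiplication:

import pytest

# Each decorator multiplies the case count: 1 * 2 * 3 = 6 test instances.
@pytest.mark.parametrize("num_tokens", [30])
@pytest.mark.parametrize("top_k", [2, 8])
@pytest.mark.parametrize("dtype", ["bfloat16", "float32", "float16"])
def test_cross_product(num_tokens, top_k, dtype):
    assert num_tokens == 30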
@@ -110,8 +112,10 @@ def test_sparse_mixer_reference():
                            [2.0, 0.0, -float('inf'), -float('inf')],
                            [0.0, 2.0, -float('inf'), -float('inf')],
                            [1.0, 1.0, 1.0, -float('inf')]],
-                          dtype=torch.float32)
+                          dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits.clone())
+    indices = indices.cpu()
+    scales = scales.cpu()
 
     assert indices.shape == (4, routing.experts_per_token)
     assert scales.shape == (4, routing.experts_per_token)
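The fixture above leans on a softmax property worth spelling out: a logit of -inf contributes exp(-inf) = 0 to the normalization, so the corresponding expert gets exactly zero probability and can never be selected. A quick check:

import torch

# An expert masked with -inf receives exactly zero probability.
logits = torch.tensor([1.0, 1.0, 1.0, -float('inf')])
print(torch.softmax(logits, dim=-1))  # tensor([0.3333, 0.3333, 0.3333, 0.0000])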
@@ -147,7 +151,7 @@ def test_load_balanced_moe_routing():
     assert routing.experts_per_token == k
 
     # Values don't matter for load balanced routing
-    logits = torch.empty((tokens, 4), dtype=torch.float32)
+    logits = torch.empty((tokens, 4), dtype=torch.float32).cuda()
 
     indices, scales = routing.apply(logits)
     assert indices.shape == (tokens, k)
@@ -164,12 +168,14 @@ def test_load_balanced_moe_routing():
 
 def test_static_moe_routing():
     routing = StaticMoeRoutingMethod(
-        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))
+        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda())
     assert routing.experts_per_token == 4
 
     logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
-                          dtype=torch.float32)
+                          dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits)
+    indices = indices.cpu()
 
     assert scales is None
     assert indices.shape == (2, 4)
     assert indices.dtype == torch.int32
@@ -178,10 +184,12 @@ def test_static_moe_routing():
         indices, torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))
 
     routing = StaticMoeRoutingMethod(
-        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32),
+        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda(),
         torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
-                     dtype=torch.float32))
+                     dtype=torch.float32).cuda())
     indices, scales = routing.apply(logits)
+    scales = scales.cpu()
 
     assert scales is not None
     assert scales.shape == (2, 4)
     assert scales.dtype == torch.float32
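Read together, the two StaticMoeRoutingMethod constructions above pin down its contract: the logits are ignored, the fixed indices are returned as-is, and scales is None unless a scale tensor was supplied. A hypothetical CPU-only sketch inferred from these assertions (the real class lives in tensorrt_llm and may differ in detail):

from typing import Optional

import torch

# Hypothetical reconstruction of static routing, inferred from the test's
# assertions; not the actual TensorRT-LLM implementation.
class StaticRoutingSketch:

    def __init__(self,
                 indices: torch.Tensor,
                 scales: Optional[torch.Tensor] = None):
        self.indices = indices  # fixed (tokens, experts_per_token) mapping
        self.scales = scales  # optional fixed routing weights
        self.experts_per_token = indices.shape[1]

    def apply(self, logits: torch.Tensor):
        # The logits are ignored: the routing decision is static.
        return self.indices, self.scales

routing = StaticRoutingSketch(
    torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))
indices, scales = routing.apply(torch.rand(2, 4))
assert scales is None and indices.shape == (2, 4)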