TensorRT-LLMs/tests/_torch/thop/test_moe_op.py
Sharan Chetlur 258c7540c0 open source 09df54c0cc99354a60bbc0303e3e8ea33a96bef0 (#2725)
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736)

Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

Add note for blackwell (#2742)

Update the docs to workaround the extra-index-url issue (#2744)

update README.md (#2751)

Fix github io pages (#2761)

Update
2025-02-11 02:21:51 +00:00

140 lines
4.4 KiB
Python

import os
import sys
import pytest
import torch
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from utils.util import skip_pre_ampere
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from helpers import reference_moe_torch
@skip_pre_ampere
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_moe_op_profile(dtype):
    """Profile the fused-MoE GEMM tactics and check bucket clamping.

    After profiling with num_tokens buckets [2, 4, 8], queries outside the
    bucket range (1 token below, 16 tokens above) should presumably reuse
    the nearest edge bucket's profile ids — the pairwise equality asserts
    below pin that down.
    """
    num_experts, hidden_size, intermediate_size = 3, 64, 32
    top_k, tp_size, tp_rank = 2, 1, 0

    torch.manual_seed(0)
    w2_weight = torch.randn((num_experts, hidden_size, intermediate_size),
                            dtype=dtype).cuda()

    profiler = torch.classes.trtllm.FusedMoeProfiler.get_instance(
        dtype, dtype, dtype)
    # Last argument is the list of num_tokens buckets to profile.
    profiler.run_profile(w2_weight, top_k, tp_size, tp_rank, [2, 4, 8])

    # Below the smallest bucket: 1 token resolves to the 2-token bucket.
    ids_for_1 = profiler.get_profile_ids(1, w2_weight, top_k)
    ids_for_2 = profiler.get_profile_ids(2, w2_weight, top_k)
    assert ids_for_1 == ids_for_2
    assert len(ids_for_1) == 2

    # Above the largest bucket: 16 tokens resolves to the 8-token bucket.
    ids_for_8 = profiler.get_profile_ids(8, w2_weight, top_k)
    ids_for_16 = profiler.get_profile_ids(16, w2_weight, top_k)
    assert ids_for_8 == ids_for_16
    assert len(ids_for_8) == 2
@skip_pre_ampere
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_moe_op_run(dtype):
    """Run the fused-MoE op with and without profile ids.

    Compares three outputs pairwise: the fused op using default tactics,
    the fused op using profiled tactics, and a pure-PyTorch reference
    (``reference_moe_torch``). Tolerances are loose because the inputs are
    fp16/bf16.
    """
    seq_len, hidden_size, intermediate_size = 8, 64, 32
    num_experts, top_k = 3, 2
    tp_size, tp_rank = 1, 0

    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    x = torch.randn((seq_len, hidden_size), dtype=dtype).cuda()
    router_logits = torch.randn((seq_len, num_experts), dtype=dtype).cuda()

    w3_w1_stacked_weight = torch.empty(
        (num_experts, intermediate_size * 2, hidden_size), dtype=dtype).cuda()
    w2_weight = torch.empty((num_experts, hidden_size, intermediate_size),
                            dtype=dtype).cuda()

    weights = {}
    for eid in range(num_experts):
        w1 = torch.randn((intermediate_size, hidden_size), dtype=dtype).cuda()
        w2 = torch.randn((hidden_size, intermediate_size), dtype=dtype).cuda()
        w3 = torch.randn((intermediate_size, hidden_size), dtype=dtype).cuda()
        weights[f"{eid}.w1.weight"] = w1
        weights[f"{eid}.w2.weight"] = w2
        weights[f"{eid}.w3.weight"] = w3
        # The fused op consumes w3 stacked on top of w1 along the row axis.
        w3_w1_stacked_weight.data[eid].copy_(torch.cat([w3, w1], dim=-2))
        w2_weight.data[eid].copy_(w2)

    def _run_fused(profile_ids):
        # Single call site for the fused op so both runs stay identical
        # except for the tactic selection.
        with torch.inference_mode():
            return torch.ops.trtllm.fused_moe(
                x,
                router_logits.float(),
                w3_w1_stacked_weight,
                w2_weight,
                dtype,
                top_k,
                quant_scales=None,
                tp_size=tp_size,
                tp_rank=tp_rank,
                profile_ids=profile_ids,
            )

    # Run with default (unprofiled) tactics.
    output_no_profile = _run_fused(None)

    # Profile, then run again with the selected tactic ids.
    profiler = torch.classes.trtllm.FusedMoeProfiler.get_instance(
        dtype, dtype, dtype)
    profiler.run_profile(w2_weight, top_k, tp_size, tp_rank, [2, 4, 8])
    profile_ids = profiler.get_profile_ids(seq_len, w2_weight, top_k)
    assert len(profile_ids) == 2
    output_with_profile = _run_fused(profile_ids)

    # Pure-PyTorch reference.
    with torch.inference_mode():
        ref_output = reference_moe_torch(x, router_logits, top_k, weights)

    torch.cuda.synchronize()
    torch.testing.assert_close(output_no_profile,
                               output_with_profile,
                               rtol=0.1,
                               atol=0.1)
    torch.testing.assert_close(output_no_profile,
                               ref_output,
                               rtol=0.2,
                               atol=0.5)
    torch.testing.assert_close(output_with_profile,
                               ref_output,
                               rtol=0.2,
                               atol=0.5)