Chore: Remove profile test. (#3565)

Because it is duplicated with test_fp4_linear. Also, cpp profiler has been unified with the new AutoTuner already.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
Yukun He 2025-04-15 14:17:51 +08:00 committed by GitHub
parent 0305942808
commit cfc6f242dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -206,61 +206,3 @@ class TestFunctional(unittest.TestCase):
if not use_ue8m0:
# The gap is too large for ue8m0, so we just make sure that it runs
self.assertTrue(torch.allclose(a_pt, aq_fp32, atol=1, rtol=0))
class TestProfiling(unittest.TestCase):
def setUp(self):
tensorrt_llm.logger.set_level("warning")
torch.manual_seed(42)
torch.cuda.manual_seed(42)
@parameterized.expand(
list([
[1024, 1024, 1024],
[512, 32, 64],
[7, 32, 32],
]),
name_func=unittest_name_func,
)
@skip_pre_blackwell_unittest
def test_fp4_quantize_gemm_torch_profiling(self, m: int, n: int, k: int):
pytest.skip("https://nvbugs/5100633")
a = torch.randn([m, k], dtype=torch.float32)
b = torch.randn([n, k], dtype=torch.float32)
a_global_sf = (448 * 6) / a.abs().max().float()
b_global_sf = (448 * 6) / b.abs().max().float()
ab_global_sf = 1 / (a_global_sf * b_global_sf)
ab_global_sf = ab_global_sf.cuda()
profiler = torch.classes.trtllm.FP4GemmRunner.get_instance(torch.half)
buckets = [1, 16, 32, 48, 64, 1024, 2048, 4096]
profiler.run_profile(n, k, buckets)
sf_vec_size = 16
a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a.half().cuda(),
a_global_sf.cuda(),
sf_vec_size, False)
b_fp4, b_sf = torch.ops.trtllm.fp4_quantize(b.half().cuda(),
b_global_sf.cuda(),
sf_vec_size, False)
a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
1 / a_global_sf,
sf_vec_size)
torch.cuda.synchronize()
b_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(b_fp4.cpu(), b_sf.cpu(),
1 / b_global_sf,
sf_vec_size)
c_ref = torch.ops.trtllm.fp4_gemm(a_fp4, b_fp4, a_sf, b_sf,
ab_global_sf, False)
best_config_idx = profiler.get_best_config_id(m, n, k)
c_actual = profiler.run_gemm(a_fp4, b_fp4, a_sf, b_sf, ab_global_sf,
False, best_config_idx)
torch.cuda.synchronize()
torch.testing.assert_close(c_actual, c_ref, atol=1e-2, rtol=0)