# TensorRT-LLMs/tests/_torch/test_fp4_gemm_quantize.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest

import pytest
import torch
from parameterized import parameterized

import tensorrt_llm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from utils.util import skip_pre_blackwell_unittest, unittest_name_func


# Used by the (fp16 -> fp4) quant layer + fp4 gemm network.
def e2m1_and_ufp8_scale_to_float_tensor_v2(
e2m1_tensor: torch.Tensor,
ufp8_scale_tensor: torch.Tensor,
global_scale_tensor: torch.Tensor,
sf_vec_size,
ufp8_type: int = 1,
):
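    """Dequantize packed e2m1 (FP4) data with its UFP8 block scale factors and a
    global scale back to a float tensor; used as the reference path in the tests
    below.
    """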
float_tensor = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2(
e2m1_tensor, ufp8_scale_tensor, global_scale_tensor, sf_vec_size,
ufp8_type)
    return float_tensor


class TestFunctional(unittest.TestCase):

    def setUp(self):
tensorrt_llm.logger.set_level("warning")
torch.manual_seed(42)
        torch.cuda.manual_seed(42)

    @parameterized.expand(
list([
[1024, 1024, 1024],
[7, 32, 32],
]),
name_func=unittest_name_func,
)
@skip_pre_blackwell_unittest
def test_fp4_quantize_gemm_torch(self, m, n, k):
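        """Quantize both GEMM operands to FP4, run the FP4 GEMM, and compare the
        result against an FP32 GEMM on the dequantized operands.
        """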
pytest.skip("https://nvbugs/5100633")
a = torch.randn([m, k], dtype=torch.float32)
b = torch.randn([n, k], dtype=torch.float32)
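        # 448 is the largest normal FP8 E4M3 value and 6 the largest FP4 E2M1
        # value, so (448 * 6) / amax maps each tensor onto the full range of the
        # two-level (global scale x per-block scale) quantization scheme.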
a_global_sf = (448 * 6) / a.abs().max().float()
b_global_sf = (448 * 6) / b.abs().max().float()
ab_global_sf = 1 / (a_global_sf * b_global_sf)
ab_global_sf = ab_global_sf.cuda()
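        # NVFP4-style block quantization: each group of 16 consecutive elements
        # shares one UFP8 scale factor.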
sf_vec_size = 16
a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a.half().cuda(),
a_global_sf.cuda(),
sf_vec_size, False)
b_fp4, b_sf = torch.ops.trtllm.fp4_quantize(b.half().cuda(),
b_global_sf.cuda(),
sf_vec_size, False)
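        # Dequantize both operands on the host to build the float reference;
        # 1 / global_sf undoes the per-tensor scale applied during quantization.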
a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
1 / a_global_sf,
sf_vec_size)
b_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(b_fp4.cpu(), b_sf.cpu(),
1 / b_global_sf,
sf_vec_size)
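        # The FP4 GEMM computes a @ b^T; ab_global_sf = 1 / (a_global_sf * b_global_sf)
        # folds both per-tensor global scales back into the output.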
c = (torch.ops.trtllm.fp4_gemm(a_fp4, b_fp4, a_sf, b_sf, ab_global_sf,
False).float().cpu())
torch.cuda.synchronize()
c_pt = torch.nn.functional.linear(a_pt, b_pt)
        self.assertTrue(torch.allclose(c_pt, c, atol=1e-2, rtol=1e-2))

    @parameterized.expand(list([[1024, 1024, torch.half, False],
[2, 512, torch.bfloat16, False],
[13, 16, torch.half, True]]),
name_func=unittest_name_func)
@skip_pre_blackwell_unittest
def test_fp4_quantize_torch(self, m, k, dtype, use_ue8m0):
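        """Round-trip check: quantize to FP4 (e2m1) with per-16-element block
        scales, dequantize, and compare against the dtype-rounded input. With
        use_ue8m0=True the block scales use the power-of-two UE8M0 format and the
        test only checks that the op runs.
        """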
a = torch.randn([m, k], dtype=torch.float32).to(dtype).float()
a_global_sf = (448 * 6) / a.abs().max().float()
sf_vec_size = 16
a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(
a.to(dtype).cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0)
a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
1 / a_global_sf,
sf_vec_size)
torch.cuda.synchronize()
if not use_ue8m0:
# The gap is too large for ue8m0, so we just make sure that it runs
            self.assertTrue(torch.allclose(a_pt, a, atol=1, rtol=0))

    @parameterized.expand(list([[64, 64, torch.float8_e4m3fn, False],
[13, 16, torch.float8_e4m3fn, True]]),
name_func=unittest_name_func)
@skip_pre_blackwell_unittest
def test_fp4_quantize_torch_fp8(self, m, k, dtype, use_ue8m0):
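        """Same FP4 round trip, but starting from an FP8 (E4M3) input tensor: the
        result is compared against the FP8-quantized values cast back to float32
        (aq_fp32) rather than against the raw random input.
        """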
assert dtype == torch.float8_e4m3fn
a = torch.randn([m, k], dtype=torch.float32)
amax = a.abs().max().float()
a_fp8 = (a / amax * 448).to(dtype)
aq_fp32 = a_fp8.float() * amax / 448
a_global_sf = (448 * 6) / amax
sf_vec_size = 16
a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a_fp8.cuda(),
a_global_sf.cuda(),
sf_vec_size, use_ue8m0)
a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
1 / a_global_sf,
sf_vec_size)
torch.cuda.synchronize()
if not use_ue8m0:
# The gap is too large for ue8m0, so we just make sure that it runs
            self.assertTrue(torch.allclose(a_pt, aq_fp32, atol=1, rtol=0))


class TestProfiling(unittest.TestCase):

    def setUp(self):
tensorrt_llm.logger.set_level("warning")
torch.manual_seed(42)
        torch.cuda.manual_seed(42)

    @parameterized.expand(
list([
[1024, 1024, 1024],
[512, 32, 64],
[7, 32, 32],
]),
name_func=unittest_name_func,
)
@skip_pre_blackwell_unittest
def test_fp4_quantize_gemm_torch_profiling(self, m: int, n: int, k: int):
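        """Profile FP4 GEMM tactics for this (n, k) shape, run the GEMM through
        the best profiled tactic, and check that it matches the default fp4_gemm
        path.
        """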
a = torch.randn([m, k], dtype=torch.float32)
b = torch.randn([n, k], dtype=torch.float32)
a_global_sf = (448 * 6) / a.abs().max().float()
b_global_sf = (448 * 6) / b.abs().max().float()
ab_global_sf = 1 / (a_global_sf * b_global_sf)
ab_global_sf = ab_global_sf.cuda()
profiler = torch.classes.trtllm.FP4GemmRunner.get_instance(torch.half)
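        # Candidate M sizes to profile for this (n, k); get_best_config_id(m, n, k)
        # below then selects the fastest profiled tactic for the actual m.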
buckets = [1, 16, 32, 48, 64, 1024, 2048, 4096]
profiler.run_profile(n, k, buckets)
sf_vec_size = 16
a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a.half().cuda(),
a_global_sf.cuda(),
sf_vec_size, False)
b_fp4, b_sf = torch.ops.trtllm.fp4_quantize(b.half().cuda(),
b_global_sf.cuda(),
sf_vec_size, False)
a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
1 / a_global_sf,
sf_vec_size)
torch.cuda.synchronize()
b_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(b_fp4.cpu(), b_sf.cpu(),
1 / b_global_sf,
sf_vec_size)
c_ref = torch.ops.trtllm.fp4_gemm(a_fp4, b_fp4, a_sf, b_sf,
ab_global_sf, False)
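        # Re-run the GEMM through the tactic chosen by the profiler and compare it
        # against the default-path reference above.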
best_config_idx = profiler.get_best_config_id(m, n, k)
c_actual = profiler.run_gemm(a_fp4, b_fp4, a_sf, b_sf, ab_global_sf,
False, best_config_idx)
torch.cuda.synchronize()
torch.testing.assert_close(c_actual, c_ref, atol=1e-2, rtol=0)