# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest

import pytest
import torch
from parameterized import parameterized

import tensorrt_llm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from utils.util import skip_pre_blackwell_unittest, unittest_name_func


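# NOTE: background for the helper and tests below, based on the commonly
# described NVFP4 recipe (not an authoritative spec): the FP4 payload is E2M1
# (max magnitude 6), packed two values per byte, with one UFP8 scale factor
# per block of sf_vec_size elements plus a single FP32 global scale per
# tensor. The helper reconstructs floats on the CPU as roughly
# float ~ e2m1_value * ufp8_scale * global_scale_tensor; the tests pass
# 1 / <tensor>_global_sf as that last argument.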
# Used by the (fp16 -> fp4) quant layer + fp4 gemm network.
def e2m1_and_ufp8_scale_to_float_tensor_v2(
    e2m1_tensor: torch.Tensor,
    ufp8_scale_tensor: torch.Tensor,
    global_scale_tensor: torch.Tensor,
    sf_vec_size,
    ufp8_type: int = 1,
):
    float_tensor = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2(
        e2m1_tensor, ufp8_scale_tensor, global_scale_tensor, sf_vec_size,
        ufp8_type)
    return float_tensor


class TestFunctional(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level("warning")
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

    @parameterized.expand(
        list([
            [1024, 1024, 1024],
            [7, 32, 32],
        ]),
        name_func=unittest_name_func,
    )
    @skip_pre_blackwell_unittest
    def test_fp4_quantize_gemm_torch(self, m, n, k):
        pytest.skip("https://nvbugs/5100633")
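        # NOTE: (448 * 6) is max(E4M3) * max(E2M1); a global scale of
        # (448 * 6) / amax is the usual NVFP4 choice so that the per-block
        # scale factors stay representable in E4M3. ab_global_sf undoes both
        # input scales on the GEMM output, and sf_vec_size = 16 is the NVFP4
        # block size.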
        a = torch.randn([m, k], dtype=torch.float32)
        b = torch.randn([n, k], dtype=torch.float32)
        a_global_sf = (448 * 6) / a.abs().max().float()
        b_global_sf = (448 * 6) / b.abs().max().float()
        ab_global_sf = 1 / (a_global_sf * b_global_sf)
        ab_global_sf = ab_global_sf.cuda()

        sf_vec_size = 16
        a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a.half().cuda(),
                                                    a_global_sf.cuda(),
                                                    sf_vec_size, False)
        b_fp4, b_sf = torch.ops.trtllm.fp4_quantize(b.half().cuda(),
                                                    b_global_sf.cuda(),
                                                    sf_vec_size, False)

        a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
                                                      1 / a_global_sf,
                                                      sf_vec_size)
        b_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(b_fp4.cpu(), b_sf.cpu(),
                                                      1 / b_global_sf,
                                                      sf_vec_size)

        c = (torch.ops.trtllm.fp4_gemm(a_fp4, b_fp4, a_sf, b_sf, ab_global_sf,
                                       False).float().cpu())

        torch.cuda.synchronize()
        c_pt = torch.nn.functional.linear(a_pt, b_pt)
        self.assertTrue(torch.allclose(c_pt, c, atol=1e-2, rtol=1e-2))

    @parameterized.expand(list([[1024, 1024, torch.half, False],
                                [2, 512, torch.bfloat16, False],
                                [13, 16, torch.half, True]]),
                          name_func=unittest_name_func)
    @skip_pre_blackwell_unittest
    def test_fp4_quantize_torch(self, m, k, dtype, use_ue8m0):
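        # NOTE: when use_ue8m0 is True the block scale factors are stored as
        # UE8M0 (an exponent-only, power-of-two format), which is much coarser
        # than E4M3; that is why only the E4M3 path is checked numerically
        # below.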
        a = torch.randn([m, k], dtype=torch.float32).to(dtype).float()
        a_global_sf = (448 * 6) / a.abs().max().float()
        sf_vec_size = 16

        a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(
            a.to(dtype).cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0)

        a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
                                                      1 / a_global_sf,
                                                      sf_vec_size)

        torch.cuda.synchronize()
        if not use_ue8m0:
            # The gap is too large for ue8m0, so we just make sure that it runs
            self.assertTrue(torch.allclose(a_pt, a, atol=1, rtol=0))

    @parameterized.expand(list([[64, 64, torch.float8_e4m3fn, False],
                                [13, 16, torch.float8_e4m3fn, True]]),
                          name_func=unittest_name_func)
    @skip_pre_blackwell_unittest
    def test_fp4_quantize_torch_fp8(self, m, k, dtype, use_ue8m0):
        assert dtype == torch.float8_e4m3fn
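        # The input is first round-tripped through FP8 (E4M3, max 448), so the
        # dequantized FP4 output is compared against aq_fp32, the dequantized
        # FP8 tensor, rather than the raw float input.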
        a = torch.randn([m, k], dtype=torch.float32)
        amax = a.abs().max().float()
        a_fp8 = (a / amax * 448).to(dtype)
        aq_fp32 = a_fp8.float() * amax / 448
        a_global_sf = (448 * 6) / amax
        sf_vec_size = 16

        a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a_fp8.cuda(),
                                                    a_global_sf.cuda(),
                                                    sf_vec_size, use_ue8m0)

        a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
                                                      1 / a_global_sf,
                                                      sf_vec_size)

        torch.cuda.synchronize()

        if not use_ue8m0:
            # The gap is too large for ue8m0, so we just make sure that it runs
            self.assertTrue(torch.allclose(a_pt, aq_fp32, atol=1, rtol=0))


class TestProfiling(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level("warning")
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

    @parameterized.expand(
        list([
            [1024, 1024, 1024],
            [512, 32, 64],
            [7, 32, 32],
        ]),
        name_func=unittest_name_func,
    )
    @skip_pre_blackwell_unittest
    def test_fp4_quantize_gemm_torch_profiling(self, m: int, n: int, k: int):
        a = torch.randn([m, k], dtype=torch.float32)
        b = torch.randn([n, k], dtype=torch.float32)
        a_global_sf = (448 * 6) / a.abs().max().float()
        b_global_sf = (448 * 6) / b.abs().max().float()
        ab_global_sf = 1 / (a_global_sf * b_global_sf)
        ab_global_sf = ab_global_sf.cuda()

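        # NOTE: FP4GemmRunner profiles the available GEMM tactics for each M
        # bucket at fixed (n, k); get_best_config_id then returns the best
        # tactic for this shape, and run_gemm with that tactic is expected to
        # match the default fp4_gemm path used as the reference below.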
        profiler = torch.classes.trtllm.FP4GemmRunner.get_instance(torch.half)
        buckets = [1, 16, 32, 48, 64, 1024, 2048, 4096]
        profiler.run_profile(n, k, buckets)

        sf_vec_size = 16
        a_fp4, a_sf = torch.ops.trtllm.fp4_quantize(a.half().cuda(),
                                                    a_global_sf.cuda(),
                                                    sf_vec_size, False)
        b_fp4, b_sf = torch.ops.trtllm.fp4_quantize(b.half().cuda(),
                                                    b_global_sf.cuda(),
                                                    sf_vec_size, False)

        a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(a_fp4.cpu(), a_sf.cpu(),
                                                      1 / a_global_sf,
                                                      sf_vec_size)
        torch.cuda.synchronize()

        b_pt = e2m1_and_ufp8_scale_to_float_tensor_v2(b_fp4.cpu(), b_sf.cpu(),
                                                      1 / b_global_sf,
                                                      sf_vec_size)

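        # a_pt and b_pt (CPU-dequantized copies) are not used in the check
        # below; the reference for the profiled kernel is the default
        # fp4_gemm path.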
        c_ref = torch.ops.trtllm.fp4_gemm(a_fp4, b_fp4, a_sf, b_sf,
                                          ab_global_sf, False)

        best_config_idx = profiler.get_best_config_id(m, n, k)
        c_actual = profiler.run_gemm(a_fp4, b_fp4, a_sf, b_sf, ab_global_sf,
                                     False, best_config_idx)

        torch.cuda.synchronize()

        torch.testing.assert_close(c_actual, c_ref, atol=1e-2, rtol=0)