TensorRT-LLM/tests/unittest/_torch/test_fp8_quantize.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import pytest
import torch

from utils.util import getSMVersion


def _dequant_fp8(input, scale, transpose_scale, block_m, block_n):
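    """Reference block-wise FP8 dequantization.

    `block_m` / `block_n` select a 128-wide tile along the row / column
    dimension, giving three layouts: 1x128 (per-row column blocks),
    128x1 (per-column row blocks), and 128x128 (2D blocks). When
    `transpose_scale` is set, `scale` is transposed before use.
    """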
    input = input.to(torch.float)
    scale = scale.to(torch.float)
    if transpose_scale:
        scale = scale.t()
    output = torch.zeros_like(input)
    m, n = input.shape
    m_tile = 128 if block_m else 1
    n_tile = 128 if block_n else 1
    if m_tile == 1:
        # 1x128 blocks: one scale per row for each 128-column block.
        assert n % 16 == 0, "n must be divisible by 16"
        total_blocks = math.ceil(n / 128)
        for block in range(total_blocks):
            # Calculate the column range of this block.
            start_col = block * 128
            end_col = min(start_col + 128, n)
            output[:, start_col:end_col] = (
                input[:, start_col:end_col] * scale[:, block].view(-1, 1))
    elif n_tile == 1:
        # 128x1 blocks: one scale per column for each 128-row block.
        assert m % 16 == 0, "m must be divisible by 16"
        total_blocks = math.ceil(m / 128)
        for block in range(total_blocks):
            # Calculate the row range of this block.
            start_row = block * 128
            end_row = min(start_row + 128, m)
            output[start_row:end_row, :] = (
                input[start_row:end_row, :] * scale[block, :])
    else:
        # 128x128 blocks: one scale per 2D block.
        assert n % 16 == 0, "n must be divisible by 16"
        assert m % 16 == 0, "m must be divisible by 16"
        n_blocks = math.ceil(n / 128)
        m_blocks = math.ceil(m / 128)
        for i in range(n_blocks):
            for j in range(m_blocks):
                start_row = j * 128
                end_row = min(start_row + 128, m)
                start_col = i * 128
                end_col = min(start_col + 128, n)
                output[start_row:end_row, start_col:end_col] = (
                    input[start_row:end_row, start_col:end_col] * scale[j, i])
    return output
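

# A minimal reference sketch (not used by the test) of what 1x128
# quantization is assumed to do here: each row is split into 128-column
# blocks, and each block is scaled so its max |value| maps to the FP8
# E4M3 range. `torch.ops.trtllm.fp8_quantize_1x128` is expected to
# produce an equivalent (fp8 tensor, per-block scale) pair; the helper
# name below is hypothetical.
def _quant_fp8_1x128_ref(input):
    m, n = input.shape
    input = input.to(torch.float)
    total_blocks = math.ceil(n / 128)
    scale = torch.empty((m, total_blocks), device=input.device, dtype=torch.float)
    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3
    for block in range(total_blocks):
        start_col = block * 128
        end_col = min(start_col + 128, n)
        blk = input[:, start_col:end_col]
        # One scale per (row, 128-column block); guard against all-zero blocks.
        amax = blk.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
        s = amax / fp8_max
        output[:, start_col:end_col] = (blk / s).to(torch.float8_e4m3fn)
        scale[:, block] = s.squeeze(1)
    return output, scale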


@pytest.mark.skipif(
    getSMVersion() != 100 and getSMVersion() != 90,
    reason="Only test on Blackwell and Hopper",
)
@pytest.mark.parametrize("k", [576, 256, 32])
@pytest.mark.parametrize("m", [4, 16, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fp8_quantize_blackwell(dtype, m, k):
    torch.random.manual_seed(0)
    # TODO: padding is not handled yet, so keep m a multiple of 4.
    assert m % 4 == 0, "Disable padding for now"
    a = torch.randn((m, k), device='cuda', dtype=dtype)
    fp8_a, fp8_a_scale = torch.ops.trtllm.fp8_quantize_1x128(a)
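    # Assumption about the kernel's scale layout: one scale per
    # (row, 128-column block), stored transposed as (num_col_blocks, m).
    # The view below plus transpose_scale=True in _dequant_fp8 reshape it
    # to match the 1x128 indexing (scale[:, block]).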
    fp8_a_scale = fp8_a_scale.view(-1, fp8_a.shape[0])  # transpose the scale view
    a_dequant = _dequant_fp8(fp8_a, fp8_a_scale, True, False, True)
    torch.testing.assert_close(a_dequant.cpu().to(torch.float32),
                               a.cpu().to(torch.float32),
                               atol=1e-1,
                               rtol=1e-1)