# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00 (Python, 148 lines, 5.4 KiB).
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
from parameterized import parameterized
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm import Parameter, Tensor
|
|
from tensorrt_llm._utils import torch_to_numpy
|
|
from tensorrt_llm.quantization.functional import smooth_quant_layer_norm
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
|
from utils.util import (create_session, run_session, skip_bf16_pre_ampere,
|
|
unittest_name_func)
|
|
|
|
|
|
class TestSmoothQuantLayerNorm(unittest.TestCase):
    """Compare TensorRT-LLM's SmoothQuant LayerNorm with a PyTorch reference.

    Each parameterized case builds a TensorRT network around
    ``smooth_quant_layer_norm``, runs it on random CUDA activations, and
    checks the int8 output (plus the dynamic activation scales, when that
    mode is enabled) against ``torch.nn.LayerNorm`` followed by the same
    quantization performed in PyTorch.
    """

    def setUp(self):
        # Fixed seed so the random activations and the static scale are
        # reproducible from run to run.
        torch.manual_seed(1997)
        tensorrt_llm.logger.set_level('error')

    def load_test_cases():
        """Return the parameter tuples consumed by ``parameterized.expand``.

        Evaluated at class-definition time, so it takes no ``self``. Each
        base case ``(dtype, dynamic_act_scaling, elementwise_affine)`` is
        extended with ``(remove_batch_dim, use_plugin)``. ``use_plugin``
        varies slowest, then ``remove_batch_dim``, then the base case.
        """
        base_cases = [
            ('float16', False, True),
            ('float16', True, True),
            ('bfloat16', False, True),
            ('bfloat16', True, True),
            ('float32', False, True),
            ('float32', True, True),
            ('float16', True, False),
        ]
        return [
            case + (remove_batch_dim, use_plugin)
            for use_plugin in (True, False)
            for remove_batch_dim in (True, False)
            for case in base_cases
        ]

    @parameterized.expand(load_test_cases, name_func=unittest_name_func)
    def test_smooth_quant_layer_norm(self, dtype, dynamic_act_scaling,
                                     elementwise_affine, remove_batch_dim,
                                     use_plugin):
        """Run one parameter combination and compare with PyTorch."""
        # Skip cases that pre-Ampere GPUs cannot run (bfloat16).
        skip_bf16_pre_ampere(dtype)

        torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)
        hidden_size = 1024
        eps = 1e-5

        # Activations are either 3-D (8, 128, hidden) or flattened to
        # (8 * 128, hidden) when the batch dimension is removed.
        if remove_batch_dim:
            input_shape = (8 * 128, hidden_size)
        else:
            input_shape = (8, 128, hidden_size)
        x_data = torch.randn(input_shape, dtype=torch_dtype, device="cuda")

        layer_norm = torch.nn.LayerNorm(hidden_size,
                                        eps=eps,
                                        dtype=torch_dtype,
                                        elementwise_affine=elementwise_affine,
                                        device="cuda")

        # Static scale: one random integer in [2, 32), stored as float32.
        static_scale = torch.randint(2,
                                     32, (1, ),
                                     dtype=torch.float32,
                                     device="cuda")

        # Build a TensorRT network that contains only the op under test.
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        if use_plugin:
            network.plugin_config.layernorm_quantization_plugin = dtype

        with tensorrt_llm.net_guard(network):
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            # Without elementwise affine there is no gamma/beta to pass in.
            weight = bias = None
            if elementwise_affine:
                weight = Parameter(
                    torch_to_numpy(layer_norm.weight.detach().cpu())).value
                bias = Parameter(
                    torch_to_numpy(layer_norm.bias.detach().cpu())).value
            scale = Parameter(torch_to_numpy(static_scale)).value

            result = smooth_quant_layer_norm(
                x,
                hidden_size,
                weight,
                bias,
                scale,
                eps,
                dynamic_act_scaling=dynamic_act_scaling)

            # In dynamic mode the op yields (quantized output, scales).
            if dynamic_act_scaling:
                quantized, scales = result
                scales.mark_output('dynamic_scales', 'float32')
            else:
                quantized = result
            quantized.mark_output('output', 'int8')

        session = create_session(builder, network, precision=dtype, int8=True)
        outputs = run_session(session, {'x': x_data})

        def saturate_to_int8(values):
            # Round to nearest, clamp to the int8 range, then cast.
            return values.round().clip(-128, 127).to(dtype=torch.int8)

        # PyTorch reference: LayerNorm output upcast to float32, then
        # quantized with either the static or the per-row dynamic scale.
        with torch.no_grad():
            ref = layer_norm(x_data).to(dtype=torch.float32)
            if not dynamic_act_scaling:
                expected_int8 = saturate_to_int8(ref * static_scale)
            else:
                # Each row is scaled so that its abs-max maps to 127.
                row_abs_max = ref.abs().max(dim=-1, keepdim=True).values
                expected_scales = row_abs_max / 127.0
                expected_int8 = saturate_to_int8(ref * (127.0 / row_abs_max))

        # atol=1 tolerates off-by-one differences from rounding.
        torch.testing.assert_close(expected_int8,
                                   outputs['output'],
                                   atol=1,
                                   rtol=0)

        if dynamic_act_scaling:
            torch.testing.assert_close(expected_scales,
                                       outputs['dynamic_scales'],
                                       atol=1e-2,
                                       rtol=1e-2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|