# TensorRT-LLMs/tests/test_export.py
# 2024-09-30 16:20:23 +08:00

# 93 lines
# 3.1 KiB
# Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from tensorrt_llm.models.convert_utils import generate_int8
def _to_flat_float64(values):
    """Return *values* as a flat float64 NumPy array."""
    if isinstance(values, torch.Tensor):
        # Convert explicitly instead of relying on implicit ``__array__``
        # coercion, which raises for tensors that require grad or that do
        # not live on the CPU.
        values = values.detach().cpu().to(torch.float64).numpy()
    return np.asarray(values, dtype=np.float64).ravel()


def dist(x, y):
    """Measure how far apart two tensors are once flattened to vectors.

    Args:
        x: A ``torch.Tensor`` or any array-like accepted by ``np.asarray``.
        y: Same as ``x``; must hold the same number of elements.

    Returns:
        A tuple ``(l2_diff, angle)`` where ``l2_diff`` is the Euclidean norm
        of ``x - y`` and ``angle`` is the angle between the two vectors in
        degrees. ``angle`` is NaN when either input has zero norm.
    """
    x = _to_flat_float64(x)
    y = _to_flat_float64(y)
    l2_x = np.linalg.norm(x)
    l2_y = np.linalg.norm(y)
    l2_diff = np.linalg.norm(x - y)
    cos = np.dot(x, y) / (l2_x * l2_y)
    # Rounding can push the cosine marginally outside [-1, 1]; clip it so
    # arccos stays defined.
    angle = np.arccos(np.clip(cos, -1, 1)) * 180 / np.pi
    return l2_diff, angle
class TestINT8Export(unittest.TestCase):
    """Check the INT8 tensors and scales returned by ``generate_int8``.

    ``setUp`` builds a seeded random GEMM ``y = x @ w`` and quantizes ``w``.
    Each test then compares a de-quantized result with its floating-point
    reference using the angle (in degrees) returned by ``dist``.
    """

    # Largest angle, in degrees, tolerated between a de-quantized result and
    # its floating-point reference.
    MAX_ANGLE_DEG = 0.5

    def setUp(self):
        """Create the reference GEMM operands and the quantized values."""
        self.rng = np.random.default_rng(42)
        w = 2 * (self.rng.random([128, 256]) - 0.5)  # weights in [-1, 1]
        x = 10 * (self.rng.random([5, 128]) - 0.5)  # x in [-5, 5]
        y = x @ w

        # Absolute maxima taken along axis 0 of each array.
        # NOTE(review): presumably these are the activation/weight ranges
        # generate_int8 derives its scales from -- confirm against
        # tensorrt_llm.models.convert_utils.
        ranges = {
            "x": torch.from_numpy(np.abs(x).max(axis=0)),
            "y": torch.from_numpy(np.abs(y).max(axis=0)),
            "w": torch.from_numpy(np.abs(w).max(axis=0)),
        }

        self.x = torch.from_numpy(x)
        self.y = torch.from_numpy(y)
        self.w = torch.from_numpy(w)
        self.values = generate_int8(self.w, ranges)

    def test_weight_quantization(self):
        """De-quantized weights must stay within the angle tolerance.

        The ``.col`` variant is additionally required to be strictly closer
        to the original weights than the plain variant.
        """
        values = self.values
        _, w_angle = dist(
            values["weight.int8"] * values["scale_w_quant_orig"], self.w)
        _, w_angle_col = dist(
            values["weight.int8.col"] * values["scale_w_quant_orig.col"],
            self.w)

        # assertLess reports both operands on failure, unlike assertTrue.
        self.assertLess(abs(w_angle), self.MAX_ANGLE_DEG)
        self.assertLess(abs(w_angle_col), self.MAX_ANGLE_DEG)
        self.assertLess(abs(w_angle_col), abs(w_angle))

    def test_e2e_gemm_quantization(self):
        """A quantized GEMM must reproduce ``y`` within the angle tolerance.

        The ``.col`` variant is additionally required to be strictly closer
        to the reference output than the plain variant.
        """
        values = self.values

        # mimic what CUTLASS would do
        x_i8 = (self.x * values["scale_x_orig_quant"]).round().clip(
            -127, 127).to(torch.int32)

        y_i32 = x_i8 @ values["weight.int8"].to(torch.int32)
        y_quant = (y_i32 * values["scale_y_accum_quant"] *
                   values["scale_y_quant_orig"])
        _, y_angle = dist(self.y, y_quant)

        y_i32_col = x_i8 @ values["weight.int8.col"].to(torch.int32)
        y_quant_col = (y_i32_col * values["scale_y_accum_quant.col"] *
                       values["scale_y_quant_orig"])
        _, y_angle_col = dist(self.y, y_quant_col)

        self.assertLess(abs(y_angle), self.MAX_ANGLE_DEG)
        self.assertLess(abs(y_angle_col), self.MAX_ANGLE_DEG)
        self.assertLess(abs(y_angle_col), abs(y_angle))
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()