# TensorRT-LLMs/tests/quantization/_utils.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import tensorrt_llm


def woq_torch_dtype(dtype):
    if dtype == "float16":
        torch_dtype = torch.half
    elif dtype == "bfloat16":
        torch_dtype = torch.bfloat16
    else:
        assert False, f"{dtype} does not support WoQ"
    return torch_dtype


def woq_all_ones(n, k, dtype):
    torch_dtype = woq_torch_dtype(dtype)
    # Init an all-ones weight operand for the matmul
    weight = torch.ones((n, k), dtype=torch_dtype, device="cuda")
    return weight


def woq_all_zeros(n, k, dtype):
    torch_dtype = woq_torch_dtype(dtype)
    # Init an all-zeros weight operand for the matmul
    weight = torch.zeros((n, k), dtype=torch_dtype, device="cuda")
    return weight


def woq_gen_weights(n, k, dtype):
    torch_dtype = woq_torch_dtype(dtype)
    # Init a random weight operand in [-1, 1) for the matmul
    weight = torch.rand((n, k), dtype=torch_dtype, device="cuda") * 2 - 1.0
    return weight


def woq_conversion(weight, wTypeId):
    # Map wTypeId to the torch weight dtype: 1 -> int8, 2 -> packed int4
    if wTypeId == 1:
        torch_wTypeId = torch.int8
    elif wTypeId == 2:
        torch_wTypeId = torch.quint4x2
    else:
        assert False, f"wTypeId={wTypeId} is not supported by WoQ"
    return torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(
        weight.cpu(), torch_wTypeId)


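# Illustrative sketch, not part of the original utilities: how woq_conversion is
# typically consumed in a test.  The three-way unpacking below is an assumption
# about the trtllm op's return order (reference weights, preprocessed weights
# for the plugin, per-channel scales); it also needs a CUDA device and a
# TensorRT-LLM build that registers the op.
def _example_woq_conversion_int8(n=16, k=32, dtype="float16"):
    weight = woq_gen_weights(n, k, dtype)
    ref_weights, processed_weights, weight_scales = woq_conversion(weight, 1)
    return ref_weights, processed_weights, weight_scales

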
def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None):
    ref = torch.matmul(mat1, ref_torch_weights)
    if bias is not None:
        ref += bias
    return ref


def woq_gt_matmul(m,
                  mat1,
                  ref_torch_weights,
                  torch_weight_scales,
                  dtype,
                  bias=None):
    mat1 = mat1.to(dtype=torch.float)
    ref_torch_weights = ref_torch_weights.to(dtype=torch.float)
    # Do matmul
    ref = torch.matmul(mat1, ref_torch_weights)
    # Prepare per-element scaling
    scaling = torch_weight_scales.expand((m, -1))
    # Scale output and cast to the right type
    ref = ref * scaling
    # Round to the nearest int to match CUDA rounding
    if dtype == "int32":
        ref = torch.round(ref)
    # Cast ref to the required output type
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
    if bias is not None:
        ref += bias
    return ref


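# Minimal sketch of a reference GEMM through woq_gt_matmul, using plain torch
# tensors instead of the trtllm quantization op; shapes, scale magnitudes and
# CPU placement are illustrative choices, not something the tests prescribe.
def _example_woq_gt_matmul(m=4, n=16, k=32, dtype="float16"):
    activation = torch.rand((m, k), dtype=torch.float32)
    int8_weight = torch.randint(-128, 128, (k, n), dtype=torch.int8)
    # One dequantization scale per output channel
    weight_scales = torch.rand((n, ), dtype=torch.float32) * 0.01
    return woq_gt_matmul(m, activation, int8_weight, weight_scales, dtype)

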
def woq_assert_near_eq(ref, act, wTypeId):
    # Match the scale in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp
    if wTypeId == 1:
        bits_in_type = 8
    else:
        bits_in_type = 4
    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))
    max_val = torch.max(abs(ref)).item()
    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
    torch.testing.assert_close(ref, act, atol=atol, rtol=1e-7)


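# Small self-contained sketch of the tolerance woq_assert_near_eq applies: for
# int8 (wTypeId == 1) the absolute tolerance is max|ref| * (1 / 128) * 1.5, so
# the constant 0.01 offset below stays inside the tolerance.  Values are
# illustrative only.
def _example_woq_tolerance():
    ref = torch.linspace(-1.0, 1.0, steps=128)
    act = ref + 0.01  # atol evaluates to 1.0 * (1 / 128) * 1.5 ~= 0.0117
    woq_assert_near_eq(ref, act, 1)

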
def gt_matmul_smooth_quant(mat1, mat2, scale_a_, scale_b_, dtype, bias=None):
    # Convert to int32 for the PyTorch ground-truth matmul with accumulation in int32
    device = mat1.device
    mat1 = mat1.to(dtype=torch.int32).cpu()
    # Transpose the second matrix to match the native PyTorch layout
    mat2 = mat2.transpose(0, 1).to(dtype=torch.int32).cpu()
    # Do matmul; int32 matmul must run on the CPU, the GPU does not support it
    ref = torch.matmul(mat1, mat2)
    ref = ref.to(device)
    m = 1
    for ii in range(len(mat1.shape) - 1):
        m *= mat1.shape[ii]
    n = mat2.shape[1]
    # Prepare per-element scaling
    scale_a = scale_a_.expand((m, 1)).float()
    scale_b = scale_b_.expand((1, n)).float()
    scaling = torch.matmul(scale_a, scale_b).reshape(ref.shape)
    # Scale output and cast to the right type
    ref = ref * scaling
    # Round to the nearest int to match CUDA rounding
    if dtype == "int32":
        ref = torch.round(ref)
    # Cast ref to the required output type
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
    if bias is not None:
        ref += bias
    return ref


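# Minimal sketch of the SmoothQuant reference path: int8 activations with
# per-token scales and int8 weights with per-channel scales, dequantized by
# gt_matmul_smooth_quant.  Shapes and scale magnitudes are arbitrary choices
# for illustration.
def _example_smooth_quant_reference(m=8, n=16, k=32):
    act = torch.randint(-128, 128, (m, k), dtype=torch.int8)
    weight = torch.randint(-128, 128, (n, k), dtype=torch.int8)
    scale_act = torch.rand((m, 1)) * 0.1     # one scale per token
    scale_weight = torch.rand((1, n)) * 0.1  # one scale per output channel
    return gt_matmul_smooth_quant(act, weight, scale_act, scale_weight,
                                  "float32")

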
def gt_matmul_fp8_rowwise(mat1, mat2, scale_a_, scale_b_, dtype, bias=None):
    # Convert to float32 for the PyTorch ground-truth matmul with accumulation in float32
    device = mat1.device
    mat1 = mat1.to(dtype=torch.float32)
    # Transpose the second matrix to match the native PyTorch layout
    mat2 = mat2.transpose(0, 1).to(dtype=torch.float32)
    # Do matmul in float32 on the original device
    ref = torch.matmul(mat1, mat2)
    ref = ref.to(device)
    m = 1
    for ii in range(len(mat1.shape) - 1):
        m *= mat1.shape[ii]
    n = mat2.shape[1]
    # Prepare per-element scaling
    scale_a = scale_a_.expand((m, 1)).float()
    scale_b = scale_b_.expand((1, n)).float()
    scaling = torch.matmul(scale_a, scale_b).reshape(ref.shape)
    # Scale output and cast to the right type
    ref = ref * scaling
    # Cast ref to the required output type
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
    if bias is not None:
        ref += bias
    return ref


def gt_quantize_per_token(x):
    xmax, _ = x.abs().max(dim=-1, keepdim=True)
    x = (x * 127.0 / xmax).round().clip(-128, 127).to(dtype=torch.int8)
    scale_act = (xmax / 127.0).reshape(-1, 1)
    return x, scale_act


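# Round-trip sketch for gt_quantize_per_token: quantize a float activation to
# int8 with per-token scales, dequantize it, and check the result stays close
# to the input.  The tolerance is an illustrative bound, not one the tests
# rely on.
def _example_per_token_round_trip(m=4, k=8):
    x = torch.rand((m, k)) * 2 - 1.0
    x_q, scale_act = gt_quantize_per_token(x)
    x_dq = x_q.to(torch.float32) * scale_act
    torch.testing.assert_close(x, x_dq, atol=1e-2, rtol=0.0)
    return x_q, scale_act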