TensorRT-LLMs/tests/quantization/_utils.py
石晓伟 548b5b7310
Update TensorRT-LLM (#2532)
* blossom-ci.yml: run vulnerability scan on blossom

* open source efb18c1256f8c9c3d47b7d0c740b83e5d5ebe0ec

---------

Co-authored-by: niukuo <6831097+niukuo@users.noreply.github.com>
Co-authored-by: pei0033 <59505847+pei0033@users.noreply.github.com>
Co-authored-by: Kyungmin Lee <30465912+lkm2835@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2024-12-04 21:16:56 +08:00

255 lines
9.2 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import tensorrt_llm
def woq_torch_dtype(dtype):
    """Map a weight-only-quantization (WoQ) activation dtype name to a torch dtype.

    Args:
        dtype: Dtype name; only ``"float16"`` and ``"bfloat16"`` are supported.

    Returns:
        ``torch.half`` for ``"float16"`` and ``torch.bfloat16`` for ``"bfloat16"``.

    Raises:
        AssertionError: If ``dtype`` is not a supported WoQ dtype.
    """
    if dtype == "float16":
        return torch.half
    if dtype == "bfloat16":
        return torch.bfloat16
    # Raise explicitly instead of ``assert False`` so the check still fires
    # under ``python -O`` (where asserts are stripped and the old code hit an
    # unbound local). AssertionError is kept for backward compatibility.
    raise AssertionError(f"{dtype} does not support WoQ")
def woq_all_ones(n, k, dtype, device="cuda"):
    """Return an ``(n, k)`` tensor filled with ones in a WoQ activation dtype.

    Args:
        n: Number of rows.
        k: Number of columns.
        dtype: Dtype name accepted by ``woq_torch_dtype`` (``"float16"`` or
            ``"bfloat16"``).
        device: Device to allocate the tensor on; defaults to ``"cuda"``.

    Returns:
        A ``(n, k)`` tensor of ones with the torch dtype for ``dtype``.
    """
    # The tensor is floating point (fp16/bf16), not int32.
    return torch.ones((n, k), dtype=woq_torch_dtype(dtype), device=device)
def woq_all_zeros(n, k, dtype, device="cuda"):
    """Return an ``(n, k)`` tensor filled with zeros in a WoQ activation dtype.

    Args:
        n: Number of rows.
        k: Number of columns.
        dtype: Dtype name accepted by ``woq_torch_dtype`` (``"float16"`` or
            ``"bfloat16"``).
        device: Device to allocate the tensor on; defaults to ``"cuda"``.

    Returns:
        A ``(n, k)`` tensor of zeros with the torch dtype for ``dtype``.
    """
    # The tensor is floating point (fp16/bf16), not int32.
    return torch.zeros((n, k), dtype=woq_torch_dtype(dtype), device=device)
def woq_gen_weights(n, k, dtype, device="cuda"):
    """Return random ``(n, k)`` weights uniformly drawn from ``[-1, 1)``.

    Args:
        n: Number of rows.
        k: Number of columns.
        dtype: Dtype name accepted by ``woq_torch_dtype`` (``"float16"`` or
            ``"bfloat16"``).
        device: Device to allocate the tensor on; defaults to ``"cuda"``.

    Returns:
        A ``(n, k)`` tensor with the torch dtype for ``dtype``.
    """
    # torch.rand samples [0, 1); the affine map shifts that to [-1, 1).
    # The tensor is floating point (fp16/bf16), not int32.
    samples = torch.rand((n, k), dtype=woq_torch_dtype(dtype), device=device)
    return samples * 2 - 1.0
def woq_conversion(weight, wTypeId):
    """Symmetrically quantize ``weight`` along its last axis for WoQ tests.

    Args:
        weight: Weight tensor to quantize; it is moved to the CPU before the
            TensorRT-LLM op is invoked.
        wTypeId: ``1`` selects int8 weights (``torch.int8``), ``2`` selects
            int4 weights (``torch.quint4x2``).

    Returns:
        The result of
        ``torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix``,
        returned unchanged.

    Raises:
        AssertionError: If ``wTypeId`` is neither 1 nor 2.
    """
    # Both int8 (wTypeId == 1) and int4 (wTypeId == 2) weight-only modes are
    # supported.
    if wTypeId == 1:
        torch_wTypeId = torch.int8
    elif wTypeId == 2:
        torch_wTypeId = torch.quint4x2
    else:
        # Explicit raise so the check survives ``python -O``; the exception
        # type matches the former ``assert False``.
        raise AssertionError(f"wTypeId={wTypeId} is not supported by WoQ")
    return torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(
        weight.cpu(), torch_wTypeId)
def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None):
    """Reference GEMM for group-wise WoQ tests.

    Args:
        mat1: Left operand.
        ref_torch_weights: Right operand used as-is (presumably the already
            dequantized reference weights — confirm with callers).
        bias: Optional tensor added in place to the product.

    Returns:
        ``mat1 @ ref_torch_weights`` (plus ``bias`` when given).
    """
    product = torch.matmul(mat1, ref_torch_weights)
    if bias is None:
        return product
    product += bias
    return product
def woq_gt_matmul(m,
                  mat1,
                  ref_torch_weights,
                  torch_weight_scales,
                  dtype,
                  bias=None):
    """Reference GEMM for per-channel weight-only quantization tests.

    The matmul runs in float32, the result is multiplied by
    ``torch_weight_scales`` expanded to ``(m, -1)``, then cast to ``dtype``.

    Args:
        m: Row count the weight scales are expanded to.
        mat1: Left operand (upcast to float32).
        ref_torch_weights: Right operand (upcast to float32).
        torch_weight_scales: Scales, expandable to ``(m, -1)`` (presumably one
            per output column — confirm with callers).
        dtype: Output dtype name resolved via
            ``tensorrt_llm._utils.str_dtype_to_torch``; ``"int32"`` rounds to
            nearest before the cast.
        bias: Optional tensor added in place after the cast.

    Returns:
        The scaled reference product in the requested dtype.
    """
    lhs = mat1.to(dtype=torch.float)
    rhs = ref_torch_weights.to(dtype=torch.float)
    out = torch.matmul(lhs, rhs) * torch_weight_scales.expand((m, -1))
    if dtype == "int32":
        # Round to nearest so the integer cast matches the CUDA kernel.
        out = torch.round(out)
    # Cast to the requested output type.
    out = out.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
    if bias is not None:
        out += bias
    return out
def woq_assert_near_eq(ref, act, wTypeId):
    """Assert ``act`` matches ``ref`` within one-and-a-half quantization steps.

    The absolute tolerance is ``max(|ref|) * 2**-(bits - 1) * 1.5`` where
    ``bits`` is 8 for ``wTypeId == 1`` and 4 otherwise; ``rtol`` is 1e-7.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` on mismatch.
    """
    # The step must agree with the scale used in
    # cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp
    num_bits = 8 if wTypeId == 1 else 4
    step = 1.0 / float(1 << (num_bits - 1))
    peak = ref.abs().max().item()
    # The 1.5x factor leaves headroom for rounding.
    tolerance = peak * step * 1.5
    torch.testing.assert_close(ref, act, atol=tolerance, rtol=1e-7)
def gt_matmul_smooth_quant(mat1, mat2, scale_a_, scale_b_, dtype, bias=None):
    """Reference SmoothQuant GEMM computed with int32 accumulation.

    Evaluates ``mat1 @ mat2.T`` in int32 on the CPU, moves the product back to
    ``mat1``'s device, scales it element-wise by the outer product of
    ``scale_a_`` and ``scale_b_``, casts it to ``dtype`` and adds ``bias``.

    Args:
        mat1: Left operand of shape ``(..., k)``; its leading dims are counted
            as ``m`` rows when the scales are expanded.
        mat2: Right operand, presumably ``(n, k)``; transposed before the
            matmul.
        scale_a_: Row scales, expandable to ``(m, 1)``.
        scale_b_: Column scales, expandable to ``(1, n)``.
        dtype: Output dtype name resolved via
            ``tensorrt_llm._utils.str_dtype_to_torch``; ``"int32"`` rounds to
            nearest before the cast.
        bias: Optional tensor added in place after the cast.

    Returns:
        The scaled reference product in the requested dtype.
    """
    home = mat1.device
    # Int32 matmul is not supported on GPU, so the integer product is formed
    # on the CPU and then moved back.
    lhs = mat1.to(dtype=torch.int32).cpu()
    rhs = mat2.transpose(0, 1).to(dtype=torch.int32).cpu()
    product = torch.matmul(lhs, rhs).to(home)

    rows = lhs.shape[:-1].numel()
    cols = rhs.shape[1]
    # (rows, 1) @ (1, cols) builds the per-element scale grid.
    grid = torch.matmul(scale_a_.expand((rows, 1)).float(),
                        scale_b_.expand((1, cols)).float())
    product = product * grid.reshape(product.shape)

    if dtype == "int32":
        # Round to nearest so the integer cast matches the CUDA kernel.
        product = torch.round(product)
    product = product.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
    if bias is not None:
        product += bias
    return product
# Note: The qweight here should be torch Tensors before being processed by `qserve_convert_linear`
# i.e., without reordering and packing. There should be no packing and reordering for scales and zeros as well.
# The zeros here are scaled zeros
def gt_qserve_gemm_per_group(qact: torch.IntTensor,
                             act_scales: torch.HalfTensor,
                             qweight: torch.IntTensor,
                             s1_scales: torch.HalfTensor,
                             s2_scales: torch.IntTensor,
                             s2_szeros: torch.IntTensor,
                             group_size=128) -> torch.HalfTensor:
    """Reference (ground-truth) QServe GEMM with per-group weight quantization.

    The weight is dequantized group-wise as ``qweight * s2_scales - s2_szeros``
    (one scale/zero per ``group_size`` consecutive input features). It is then
    multiplied with ``qact`` in int32, and the int32 result is rescaled in
    float32 by ``s1_scales`` (per output feature) and ``act_scales`` (per row).

    Args:
        qact: Quantized activations whose last dim is ``in_features``
            (presumably 2-D ``(rows, in_features)`` — confirm with callers).
        act_scales: One scale per result row; reshaped to
            ``(act_scales.shape[0], 1)``.
        qweight: int8 tensor of shape ``(out_features, in_features)`` (the
            int8 dtype is asserted below).
        s1_scales: One scale per output feature; reshaped to
            ``(1, out_features)``.
        s2_scales: Per-group scales with
            ``out_features * (in_features // group_size)`` elements.
        s2_szeros: Per-group scaled zeros, same element count as
            ``s2_scales``.
        group_size: Input features per group; must divide ``in_features`` for
            the reshapes to succeed.

    Returns:
        A float16 tensor with one column per output feature.
    """
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]
    # Step 1: Dequantize weight from int4 to int8
    # Reshape the per-group params to (out, groups, 1) so they broadcast over
    # each group of `group_size` input features.
    s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
                                  1).to(qweight.device)
    s2_scales = s2_scales.reshape(out_features, in_features // group_size,
                                  1).to(qweight.device)
    assert qweight.dtype == torch.int8
    # The kernel relies on two's complement arithmetic of int8.
    # If qweight is converted to int32 the result will not match the kernel.
    # NOTE(review): the int8 wraparound only happens if s2_scales/s2_szeros
    # are int8 as well; an int32 tensor (as the IntTensor annotation suggests)
    # would promote the arithmetic to int32 — confirm with callers.
    dequantized_weight = qweight.reshape(
        out_features, in_features // group_size,
        group_size).mul(s2_scales).sub(s2_szeros)
    dequantized_weight = dequantized_weight.reshape(out_features, in_features)
    # Step 2: Perform matrix multiplication in int32
    result = torch.matmul(qact.to(torch.int32),
                          dequantized_weight.T.to(torch.int32))
    # Step 3: Dequantize the result to float
    # Convert int GEMM result, ascales and wscales all to float, which is aligned with the QServe GEMM kernel.
    result = result.float()
    s1_scales = s1_scales.reshape(1, out_features).to(result.device).float()
    act_scales = act_scales.reshape(act_scales.shape[0],
                                    1).to(result.device).float()
    # To match the result exactly to QServe, the multiplication order must be preserved due to float rounding errors.
    # (The (rows, 1) x (1, out) scale grid is formed first, then applied.)
    result = result.mul(s1_scales.mul(act_scales))
    return result.half()
def gt_qserve_gemm_per_channel(qact: torch.IntTensor,
                               act_scales: torch.HalfTensor,
                               act_sums: torch.HalfTensor,
                               qweight: torch.CharTensor,
                               s1_scales: torch.HalfTensor,
                               s1_szeros: torch.HalfTensor) -> torch.HalfTensor:
    """Reference (ground-truth) QServe GEMM with per-channel weight quantization.

    Computes, in float32 and then cast to float16::

        (qact @ qweight.T) * (s1_scales * act_scales) - act_sums * s1_szeros

    where the scale/zero terms broadcast as ``(rows, 1)`` x ``(1, out_features)``.

    Args:
        qact: Quantized activations of shape ``(num_activations, in_features)``.
        act_scales: One scale per activation row.
        act_sums: One value per activation row, used for the zero-point term.
        qweight: Quantized weight of shape ``(out_features, in_features)``.
        s1_scales: One scale per output feature.
        s1_szeros: One scaled zero per output feature.

    Returns:
        A float16 tensor of shape ``(num_activations, out_features)``.
    """
    out_features = qweight.shape[0]
    num_activations = qact.shape[0]
    # Step 1: Perform the matrix multiplication in int32.
    result = torch.matmul(qact.to(torch.int32), qweight.T.to(torch.int32))
    # Step 2: Dequantize the result to float.
    # Convert int GEMM result, ascales and wscales all to float, which is aligned with the QServe GEMM kernel.
    result = result.float()
    s1_scales = s1_scales.reshape(1, out_features).to(result.device).float()
    act_scales = act_scales.reshape(act_scales.shape[0],
                                    1).to(result.device).float()
    # To match the result exactly to QServe, the multiplication order must be preserved due to float rounding errors.
    result = result.mul(s1_scales.mul(act_scales))
    # Step 3: Subtract the zero-point correction, i.e. the outer product
    # result[i, j] -= act_sums[i] * s1_szeros[j]. The scaled zeros are
    # subtracted as stored, with no extra negation.
    act_sums = act_sums.reshape(num_activations, 1).to(result.device).float()
    s1_szeros = s1_szeros.reshape(1, out_features).to(result.device).float()
    result = result - act_sums * s1_szeros
    return result.half()
def gt_matmul_fp8_rowwise(mat1, mat2, scale_a_, scale_b_, dtype, bias=None):
    """Reference row-wise-scaled FP8 GEMM computed in float32.

    Evaluates ``mat1 @ mat2.T`` in float32 on the operands' device, scales it
    element-wise by the outer product of ``scale_a_`` and ``scale_b_``, casts
    it to ``dtype`` and adds ``bias``.

    Args:
        mat1: Left operand of shape ``(..., k)``; its leading dims are counted
            as ``m`` rows when the scales are expanded.
        mat2: Right operand, presumably ``(n, k)``; transposed before the
            matmul.
        scale_a_: Row scales, expandable to ``(m, 1)``.
        scale_b_: Column scales, expandable to ``(1, n)``.
        dtype: Output dtype name resolved via
            ``tensorrt_llm._utils.str_dtype_to_torch``.
        bias: Optional tensor added in place after the cast.

    Returns:
        The scaled reference product in the requested dtype.
    """
    # Upcast to float32 so the reference accumulates in float32. Unlike the
    # int32 path, float32 matmul is supported on GPU, so no CPU round-trip is
    # needed and the result already lives on the operands' device.
    mat1 = mat1.to(dtype=torch.float32)
    # Transpose (n, k) -> (k, n) to match torch.matmul's layout.
    mat2 = mat2.transpose(0, 1).to(dtype=torch.float32)
    ref = torch.matmul(mat1, mat2)
    m = mat1.shape[:-1].numel()
    n = mat2.shape[1]
    # Prepare per-element scaling: (m, 1) @ (1, n) outer product.
    scale_a = scale_a_.expand((m, 1)).float()
    scale_b = scale_b_.expand((1, n)).float()
    scaling = torch.matmul(scale_a, scale_b).reshape(ref.shape)
    ref = ref * scaling
    # Cast ref to the required output type.
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
    if bias is not None:
        ref += bias
    return ref
def gt_quantize_per_token(x):
    """Reference symmetric per-token int8 quantization.

    Every slice along the last dimension is scaled so that its largest
    magnitude maps to 127. The result is rounded, clipped to ``[-128, 127]``
    and cast to int8.

    Args:
        x: Floating-point tensor, quantized independently along its last dim.

    Returns:
        Tuple ``(x_q, scale_act)``. ``x_q`` is the int8 tensor with the same
        shape as ``x``. ``scale_act`` holds ``amax / 127`` reshaped to
        ``(-1, 1)``. An all-zero slice quantizes to zeros with a scale of 0.
    """
    xmax, _ = x.abs().max(dim=-1, keepdim=True)
    # Guard all-zero slices: dividing by a zero amax yields NaN, whose int8
    # cast is undefined. Dividing by 1 instead maps such slices to 0, while
    # every other slice is computed exactly as before.
    # NOTE(review): for float16 input, x * 127.0 overflows to inf once
    # |x| > ~515 and those elements then clip to 127/-128; upcast if inputs
    # can be that large.
    divisor = torch.where(xmax == 0, torch.ones_like(xmax), xmax)
    x_q = (x * 127.0 / divisor).round().clip(-128, 127).to(dtype=torch.int8)
    scale_act = (xmax / 127.0).reshape(-1, 1)
    return x_q, scale_act