# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

import tensorrt_llm


def woq_torch_dtype(dtype):
    if dtype == "float16":
        torch_dtype = torch.half
    elif dtype == "bfloat16":
        torch_dtype = torch.bfloat16
    else:
        assert False, f"{dtype} is not supported by WoQ"
    return torch_dtype


def woq_all_ones(n, k, dtype):
    torch_dtype = woq_torch_dtype(dtype)
    # Init an all-ones weight operand on the GPU in the requested float dtype
    weight = torch.ones((n, k), dtype=torch_dtype, device="cuda")
    return weight


def woq_all_zeros(n, k, dtype):
    torch_dtype = woq_torch_dtype(dtype)
    # Init an all-zeros weight operand on the GPU in the requested float dtype
    weight = torch.zeros((n, k), dtype=torch_dtype, device="cuda")
    return weight


def woq_gen_weights(n, k, dtype):
    torch_dtype = woq_torch_dtype(dtype)
    # Init a random weight operand uniformly distributed in [-1, 1)
    weight = torch.rand((n, k), dtype=torch_dtype, device="cuda") * 2 - 1.0
    return weight


def woq_conversion(weight, wTypeId):
    # wTypeId 1 selects int8 weights, wTypeId 2 selects packed int4 weights
    if wTypeId == 1:
        torch_wTypeId = torch.int8
    elif wTypeId == 2:
        torch_wTypeId = torch.quint4x2
    else:
        assert False, f"wTypeId={wTypeId} is not supported by WoQ"
    return torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(
        weight.cpu(), torch_wTypeId)


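# Illustrative sketch (not part of the original helpers): how the weight
# generators and `woq_conversion` above are typically combined. The helper
# name below is hypothetical; running it assumes a CUDA device and that
# importing tensorrt_llm has registered the `trtllm` torch ops used by
# `woq_conversion`.
def _example_woq_int8_quantize(n=64, k=128):
    # Draw a random fp16 weight in [-1, 1) and quantize it with int8
    # weight-only quantization (wTypeId=1). The result is whatever the
    # custom op returns, passed through unchanged.
    weight = woq_gen_weights(n, k, "float16")
    return woq_conversion(weight, 1)

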
def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None):
    ref = torch.matmul(mat1, ref_torch_weights)
    if bias is not None:
        ref += bias
    return ref


def woq_gt_matmul(m,
                  mat1,
                  ref_torch_weights,
                  torch_weight_scales,
                  dtype,
                  bias=None):
    mat1 = mat1.to(dtype=torch.float)
    ref_torch_weights = ref_torch_weights.to(dtype=torch.float)
    # Do matmul
    ref = torch.matmul(mat1, ref_torch_weights)

    # Prepare per element scaling
    scaling = torch_weight_scales.expand((m, -1))
    # Scale output and cast to right type
    ref = ref * scaling

    # Round to the nearest int to match CUDA rounding
    if dtype == "int32":
        ref = torch.round(ref)

    # Cast ref to the required output type
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))

    if bias is not None:
        ref += bias

    return ref


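# Illustrative sketch (hypothetical helper, not in the original file): the
# shapes `woq_gt_matmul` expects. Random int8 codes and per-output-channel
# scales stand in for the tensors returned by `woq_conversion`.
def _example_woq_gt_matmul(m=4, n=16, k=32):
    mat1 = torch.randn(m, k)
    ref_torch_weights = torch.randint(-128, 128, (k, n), dtype=torch.int8)
    torch_weight_scales = torch.rand(n)
    return woq_gt_matmul(m, mat1, ref_torch_weights, torch_weight_scales,
                         "float16")

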
def woq_assert_near_eq(ref, act, wTypeId):
    # Match the scale in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp
    if wTypeId == 1:
        bits_in_type = 8
    else:
        bits_in_type = 4
    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))

    max_val = torch.max(abs(ref)).item()
    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
    torch.testing.assert_close(ref, act, atol=atol, rtol=1e-7)


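# Illustrative sketch (hypothetical helper, not in the original file): the
# tolerance above allows roughly 1.5x one quantization step of the largest
# value, so a reference perturbed by less than one step should pass the check.
def _example_woq_tolerance_check():
    ref = torch.randn(8, 8)
    # For int8 (wTypeId=1) the allowed atol is max|ref| / 128 * 1.5; perturb by
    # a quarter of one step so the comparison is expected to succeed.
    act = ref + 0.25 * ref.abs().max() / 128
    woq_assert_near_eq(ref, act, 1)

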
def gt_matmul_smooth_quant(mat1, mat2, scale_a_, scale_b_, dtype, bias=None):
    # Convert to int32 for PyTorch GT Matmul with accumulation in int32.
    device = mat1.device
    mat1 = mat1.to(dtype=torch.int32).cpu()
    # Transpose the second matrix to support the native PyTorch format
    mat2 = mat2.transpose(0, 1).to(dtype=torch.int32).cpu()
    # Do matmul on the CPU; PyTorch does not support int32 matmul on the GPU
    ref = torch.matmul(mat1, mat2)
    ref = ref.to(device)

    m = 1
    for ii in range(len(mat1.shape) - 1):
        m *= mat1.shape[ii]
    n = mat2.shape[1]

    # Prepare per element scaling
    scale_a = scale_a_.expand((m, 1)).float()
    scale_b = scale_b_.expand((1, n)).float()
    scaling = torch.matmul(scale_a, scale_b).reshape(ref.shape)
    # Scale output and cast to right type
    ref = ref * scaling

    # Round to the nearest int to match CUDA rounding
    if dtype == "int32":
        ref = torch.round(ref)

    # Cast ref to the required output type
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))

    if bias is not None:
        ref += bias

    return ref


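# Illustrative sketch (hypothetical helper, not in the original file): the
# shapes and dtypes `gt_matmul_smooth_quant` expects. The int8 operands and the
# per-token / per-channel scales are random stand-ins for the quantized tensors
# produced in the real tests.
def _example_smooth_quant_reference(m=4, n=16, k=32):
    mat1 = torch.randint(-128, 128, (m, k), dtype=torch.int8)  # activations
    mat2 = torch.randint(-128, 128, (n, k), dtype=torch.int8)  # weights (n, k)
    scale_a = torch.rand(m, 1) * 1e-2  # per-token activation scales
    scale_b = torch.rand(1, n) * 1e-2  # per-channel weight scales
    return gt_matmul_smooth_quant(mat1, mat2, scale_a, scale_b, "float16")

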
# Note: The qweight here should be a torch Tensor before being processed by
# `qserve_convert_linear`, i.e., without reordering and packing. There should
# be no packing or reordering for the scales and zeros either.
# The zeros here are scaled zeros.
def gt_qserve_gemm_per_group(qact: torch.IntTensor,
                             act_scales: torch.HalfTensor,
                             qweight: torch.IntTensor,
                             s1_scales: torch.HalfTensor,
                             s2_scales: torch.IntTensor,
                             s2_szeros: torch.IntTensor,
                             group_size=128) -> torch.HalfTensor:
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    # Step 1: Dequantize weight from int4 to int8
    s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
                                  1).to(qweight.device)
    s2_scales = s2_scales.reshape(out_features, in_features // group_size,
                                  1).to(qweight.device)

    assert qweight.dtype == torch.int8
    # The kernel relies on two's complement arithmetic of int8.
    # If qweight is converted to int32 the result will not match the kernel.
    dequantized_weight = qweight.reshape(
        out_features, in_features // group_size,
        group_size).mul(s2_scales).sub(s2_szeros)
    dequantized_weight = dequantized_weight.reshape(out_features, in_features)

    # Step 2: Perform matrix multiplication in int32
    result = torch.matmul(qact.to(torch.int32),
                          dequantized_weight.T.to(torch.int32))

    # Step 3: Dequantize the result to float
    # Convert the int GEMM result, ascales and wscales all to float, which is
    # aligned with the QServe GEMM kernel.
    result = result.float()
    s1_scales = s1_scales.reshape(1, out_features).to(result.device).float()
    act_scales = act_scales.reshape(act_scales.shape[0],
                                    1).to(result.device).float()
    # To match the result exactly to QServe, the multiplication order must be
    # preserved due to float rounding errors.
    result = result.mul(s1_scales.mul(act_scales))
    return result.half()


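# Illustrative sketch (hypothetical helper, not in the original file): tensor
# shapes accepted by `gt_qserve_gemm_per_group`, filled with random small
# values. The real tests derive the weight codes, group-wise scales and scaled
# zeros from an actual QServe quantization of an fp16 weight; this only
# exercises the shapes and dtypes.
def _example_qserve_per_group_reference(m=2, out_features=32, in_features=256,
                                        group_size=128):
    qact = torch.randint(-128, 128, (m, in_features), dtype=torch.int8)
    act_scales = torch.rand(m, dtype=torch.half)
    # Unsigned 4-bit weight codes stored in an int8 tensor, as the assert
    # inside the reference requires.
    qweight = torch.randint(0, 16, (out_features, in_features),
                            dtype=torch.int8)
    s1_scales = torch.rand(out_features, dtype=torch.half)
    s2_scales = torch.randint(1, 8, (out_features, in_features // group_size),
                              dtype=torch.int32)
    s2_szeros = torch.randint(0, 8, (out_features, in_features // group_size),
                              dtype=torch.int32)
    return gt_qserve_gemm_per_group(qact, act_scales, qweight, s1_scales,
                                    s2_scales, s2_szeros, group_size)

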
def gt_qserve_gemm_per_channel(qact: torch.IntTensor,
                               act_scales: torch.HalfTensor,
                               act_sums: torch.HalfTensor,
                               qweight: torch.CharTensor,
                               s1_scales: torch.HalfTensor,
                               s1_szeros: torch.HalfTensor) -> torch.HalfTensor:
    out_features = qweight.shape[0]
    num_activations = qact.shape[0]

    # Step 1: Perform matrix multiplication in int32
    result = torch.matmul(qact.to(torch.int32), qweight.T.to(torch.int32))

    # Step 2: Dequantize the result to float
    # Convert the int GEMM result, ascales and wscales all to float, which is
    # aligned with the QServe GEMM kernel.
    result = result.float()
    s1_scales = s1_scales.reshape(1, out_features).to(result.device).float()
    act_scales = act_scales.reshape(act_scales.shape[0],
                                    1).to(result.device).float()
    # To match the result exactly to QServe, the multiplication order must be
    # preserved due to float rounding errors.
    result = result.mul(s1_scales.mul(act_scales))
    # Step 3: Add the outer product between act_sums and s1_szeros
    # Note: no unary minus before the zeros, unlike in the per-group version.
    act_sums = act_sums.reshape(num_activations, 1).to(result.device).float()
    s1_szeros = s1_szeros.reshape(1, out_features).to(result.device).float()
    result = result - act_sums * s1_szeros

    return result.half()


def gt_matmul_fp8_rowwise(mat1, mat2, scale_a_, scale_b_, dtype, bias=None):
    # Convert to float32 for PyTorch GT Matmul with accumulation in float32.
    device = mat1.device
    mat1 = mat1.to(dtype=torch.float32)
    # Transpose the second matrix to support the native PyTorch format
    mat2 = mat2.transpose(0, 1).to(dtype=torch.float32)
    # Do matmul in float32
    ref = torch.matmul(mat1, mat2)
    ref = ref.to(device)

    m = 1
    for ii in range(len(mat1.shape) - 1):
        m *= mat1.shape[ii]
    n = mat2.shape[1]

    # Prepare per element scaling
    scale_a = scale_a_.expand((m, 1)).float()
    scale_b = scale_b_.expand((1, n)).float()
    scaling = torch.matmul(scale_a, scale_b).reshape(ref.shape)
    # Scale output and cast to right type
    ref = ref * scaling

    # Cast ref to the required output type
    ref = ref.to(dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))

    if bias is not None:
        ref += bias

    return ref


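# Illustrative sketch (hypothetical helper, not in the original file): the fp8
# rowwise reference only needs per-row and per-column scale vectors; fp16
# tensors stand in here for the fp8 operands used by the real tests.
def _example_fp8_rowwise_reference(m=4, n=16, k=32):
    mat1 = torch.randn(m, k, dtype=torch.float16)  # activations
    mat2 = torch.randn(n, k, dtype=torch.float16)  # weights (n, k)
    scale_a = torch.rand(m, 1)  # per-token (row-wise) scales
    scale_b = torch.rand(1, n)  # per-channel (column-wise) scales
    return gt_matmul_fp8_rowwise(mat1, mat2, scale_a, scale_b, "float16")

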
def gt_quantize_per_token(x):
    xmax, _ = x.abs().max(dim=-1, keepdim=True)
    x = (x * 127.0 / xmax).round().clip(-128, 127).to(dtype=torch.int8)
    scale_act = (xmax / 127.0).reshape(-1, 1)
    return x, scale_act
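

# Illustrative sketch (hypothetical helper, not in the original file):
# round-trip a random activation through the per-token quantizer above and
# reconstruct it with the returned per-token scales.
def _example_per_token_roundtrip(m=4, k=32):
    x = torch.randn(m, k)
    x_q, scale_act = gt_quantize_per_token(x)
    # Dequantized approximation; the error per element is bounded by half a
    # quantization step (xmax / 127) of its row.
    x_approx = x_q.float() * scale_act
    return x, x_approx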