TensorRT-LLMs/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
Yihan Wang 9df4dad3b6
[None][fix] Introduce inline namespace to avoid symbol collision (#9541)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
2025-12-12 23:32:15 +08:00

441 lines
18 KiB
C++

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
#include "tensorrt_llm/thop/thUtils.h"
#if defined(TORCH_VERSION_MAJOR) \
&& ((TORCH_VERSION_MAJOR > 1) || ((TORCH_VERSION_MAJOR == 1) && (TORCH_VERSION_MINOR >= 9)))
#define TORCH_IS_AT_LEAST_v190
#endif
TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
using torch::Tensor;
using namespace tensorrt_llm::kernels::cutlass_kernels;
// Rejects any weight quantization type this file cannot handle.
//
// int8 is always accepted. QUInt4x2 (packed int4) is accepted only when the
// build sees torch >= 1.9 (TORCH_IS_AT_LEAST_v190). Any other type raises via
// TORCH_CHECK.
void check_quant_type_allowed(torch::ScalarType quant_type)
{
    bool const is_int8 = (quant_type == torch::kInt8);
#ifdef TORCH_IS_AT_LEAST_v190
    bool const is_packed_int4 = (quant_type == at::ScalarType::QUInt4x2);
    TORCH_CHECK(is_int8 || is_packed_int4, "Must be int4 or int8 quantization");
#else
    TORCH_CHECK(is_int8, "Must be int8 quantization");
#endif
}
// Maps a torch (weight type, activation type) pair onto the kernel-side QuantType.
//
// Resolution order:
//   1. FP8 (e4m3fn) activations always select W4_AFP8, whatever quant_type is.
//   2. int8 weights select W8_A16.
//   3. QUInt4x2 weights select W4_A16 (only when built against torch >= 1.9).
// Anything else raises via TORCH_CHECK.
QuantType get_ft_quant_type(torch::ScalarType quant_type, torch::ScalarType activation_type = torch::kFloat16)
{
    if (activation_type == torch::kFloat8_e4m3fn)
    {
        return QuantType::W4_AFP8;
    }
    if (quant_type == torch::kInt8)
    {
        return QuantType::W8_A16;
    }
#ifdef TORCH_IS_AT_LEAST_v190
    if (quant_type == at::ScalarType::QUInt4x2)
    {
        return QuantType::W4_A16;
    }
#endif
    TORCH_CHECK(false, "Invalid quantization type");
}
// Torch-facing wrapper around the kernel-side row permutation of B.
// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures.
//
// quantized_tensor: contiguous CPU int8 tensor of rank 2 or 3; a rank-3 tensor
//                   is treated as a batch whose leading dimension is the expert count.
// quant_type:       int8 or QUInt4x2; decides how many logical columns one byte holds.
// arch_version:     forwarded unchanged to the kernel helper.
// Returns a new tensor with the same shape/dtype as the input holding the permuted bytes.
Tensor permute_B_rows_for_mixed_gemm(Tensor quantized_tensor, torch::ScalarType quant_type, const int64_t arch_version)
{
    auto const input_dtype = quantized_tensor.scalar_type();
    CHECK_CPU(quantized_tensor);
    CHECK_CONTIGUOUS(quantized_tensor);
    TORCH_CHECK(input_dtype == torch::kInt8, "Quantized tensor must be int8 dtype");
    check_quant_type_allowed(quant_type);

    bool const is_matrix = quantized_tensor.dim() == 2;
    TORCH_CHECK(is_matrix || quantized_tensor.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");

    QuantType const kernel_quant_type = get_ft_quant_type(quant_type);
    size_t const weight_bits = get_weight_quant_bits(kernel_quant_type);
    // Each stored byte carries 8 / weight_bits logical elements.
    size_t const elts_per_byte = 8 / weight_bits;

    size_t const expert_count = is_matrix ? 1 : quantized_tensor.size(0);
    size_t const row_count = quantized_tensor.size(-2);
    size_t const col_count = elts_per_byte * quantized_tensor.size(-1);

    Tensor permuted = torch::empty_like(quantized_tensor);
    int8_t* src_bytes = get_ptr<int8_t>(quantized_tensor);
    int8_t* dst_bytes = get_ptr<int8_t>(permuted);

    // Resolves to the raw-pointer overload from cutlass_kernels.
    permute_B_rows_for_mixed_gemm(
        dst_bytes, src_bytes, {expert_count, row_count, col_count}, kernel_quant_type, arch_version);
    return permuted;
}
// Torch-facing wrapper around the kernel-side sub-byte transpose.
// A dedicated transpose is needed so that packed int4 and int8 data are handled correctly.
//
// quantized_tensor: contiguous CPU int8 tensor of rank 2 or 3 (rank 3 = leading expert dimension).
// quant_type:       int8 or QUInt4x2; decides how many logical columns one byte holds.
// Returns a new tensor allocated with the same shape/dtype as the input and filled by the kernel helper.
Tensor subbyte_transpose(Tensor quantized_tensor, torch::ScalarType quant_type)
{
    auto const input_dtype = quantized_tensor.scalar_type();
    CHECK_CPU(quantized_tensor);
    CHECK_CONTIGUOUS(quantized_tensor);
    TORCH_CHECK(input_dtype == torch::kInt8, "Quantized tensor must be int8 dtype");
    check_quant_type_allowed(quant_type);

    auto const rank = quantized_tensor.dim();
    TORCH_CHECK(rank == 2 || rank == 3, "Invalid dim. The dim of weight should be 2 or 3");

    QuantType const kernel_quant_type = get_ft_quant_type(quant_type);
    size_t const weight_bits = get_weight_quant_bits(kernel_quant_type);

    size_t expert_count = 1;
    if (rank != 2)
    {
        expert_count = quantized_tensor.size(0);
    }
    size_t const row_count = quantized_tensor.size(-2);
    // Logical column count: bytes in the last dimension times elements per byte.
    size_t const col_count = (8 / weight_bits) * quantized_tensor.size(-1);

    Tensor result = torch::empty_like(quantized_tensor);
    int8_t* src_bytes = get_ptr<int8_t>(quantized_tensor);
    int8_t* dst_bytes = get_ptr<int8_t>(result);

    // Resolves to the raw-pointer overload from cutlass_kernels.
    subbyte_transpose(dst_bytes, src_bytes, {expert_count, row_count, col_count}, kernel_quant_type);
    return result;
}
// Torch-facing wrapper that runs the kernel-side weight preprocessing for mixed-type GEMM.
//
// The activation type is part of the interface because W4A8 and W4A16 have different layouts;
// it is forwarded to get_ft_quant_type() together with quant_type to pick the kernel QuantType.
//
// row_major_quantized_weight: contiguous CPU int8 tensor of rank 2 or 3 (rank 3 = leading expert dimension).
// Returns a new tensor with the same shape/dtype as the input holding the preprocessed bytes.
Tensor preprocess_weights_for_mixed_gemm(
    Tensor row_major_quantized_weight, torch::ScalarType quant_type, torch::ScalarType activation_type)
{
    auto const input_dtype = row_major_quantized_weight.scalar_type();
    CHECK_CPU(row_major_quantized_weight);
    CHECK_CONTIGUOUS(row_major_quantized_weight);
    TORCH_CHECK(input_dtype == torch::kInt8, "Quantized tensor must be int8 dtype");
    check_quant_type_allowed(quant_type);

    bool const is_batched = row_major_quantized_weight.dim() == 3;
    TORCH_CHECK(
        row_major_quantized_weight.dim() == 2 || is_batched, "Invalid dim. The dim of weight should be 2 or 3");

    QuantType const kernel_quant_type = get_ft_quant_type(quant_type, activation_type);
    size_t const weight_bits = get_weight_quant_bits(kernel_quant_type);
    size_t const elts_per_byte = 8 / weight_bits;

    size_t const expert_count = is_batched ? row_major_quantized_weight.size(0) : 1;
    size_t const row_count = row_major_quantized_weight.size(-2);
    size_t const col_count = elts_per_byte * row_major_quantized_weight.size(-1);

    // The output starts zero-filled (zeros_like rather than empty_like).
    Tensor processed = torch::zeros_like(row_major_quantized_weight);
    int8_t* src_bytes = get_ptr<int8_t>(row_major_quantized_weight);
    int8_t* dst_bytes = get_ptr<int8_t>(processed);

    // WAR for MoE 3-D tensors: batched inputs always request interleaving.
    bool const force_interleave = is_batched;

    // Resolves to the raw-pointer overload from cutlass_kernels.
    preprocess_weights_for_mixed_gemm(
        dst_bytes, src_bytes, {expert_count, row_count, col_count}, kernel_quant_type, force_interleave);
    return processed;
}
// Shared implementation behind the symmetric-quantization operators.
//
// weight:     contiguous, non-empty CPU tensor of rank 2 (rows, cols) or rank 3
//             (experts, rows, cols) with dtype FP32, FP16 or BF16 (BF16 only when
//             the build defines ENABLE_BF16).
// quant_type: int8 or QUInt4x2 (see check_quant_type_allowed).
// return_unprocessed_quantized_tensor:
//             when true the result is {unprocessed, processed, scales};
//             otherwise it is {processed, scales}.
//
// The quantized tensors are int8 with the last dimension shrunk to
// cols * bits / 8 bytes; scales has the dtype of `weight` and shape (cols) or
// (experts, cols). Raises via TORCH_CHECK on any invalid argument.
std::vector<Tensor> symmetric_quantize_helper(
    Tensor weight, torch::ScalarType quant_type, bool return_unprocessed_quantized_tensor)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");

    auto _st = weight.scalar_type();
    // FP32 is accepted here and dispatched below, so the message must list it as well.
    TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16 || _st == torch::kBFloat16,
        "Invalid datatype. Weight must be FP32, FP16 or BF16");
    check_quant_type_allowed(quant_type);
    QuantType ft_quant_type = get_ft_quant_type(quant_type);

    bool const is_batched = weight.dim() == 3;
    const size_t num_experts = is_batched ? weight.size(0) : 1;
    const size_t num_rows = weight.size(-2);
    const size_t num_cols = weight.size(-1);

    const size_t bits_in_type = get_weight_quant_bits(ft_quant_type);
    // NOTE(review): integer division truncates when num_cols * bits_in_type is not a
    // multiple of 8 (e.g. int4 with an odd column count) - confirm callers never do that.
    const size_t bytes_per_out_col = num_cols * bits_in_type / 8;

    // Rank was validated above, so only the 2-D and 3-D layouts can occur.
    std::vector<int64_t> quantized_weight_shape{
        static_cast<int64_t>(num_rows), static_cast<int64_t>(bytes_per_out_col)};
    std::vector<int64_t> scale_shape{static_cast<int64_t>(num_cols)};
    if (is_batched)
    {
        quantized_weight_shape.insert(quantized_weight_shape.begin(), static_cast<int64_t>(num_experts));
        scale_shape.insert(scale_shape.begin(), static_cast<int64_t>(num_experts));
    }

    Tensor unprocessed_quantized_weight
        = torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));
    Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);
    Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));

    int8_t* unprocessed_quantized_weight_ptr = get_ptr<int8_t>(unprocessed_quantized_weight);
    int8_t* processed_quantized_weight_ptr = get_ptr<int8_t>(processed_quantized_weight);

    // TODO This should be removed if Grouped GEMM is updated to not need interleaved input
    bool force_interleave = is_batched;

    // Dispatch on the weight dtype; scales are produced in the same dtype.
    if (_st == at::ScalarType::Float)
    {
        symmetric_quantize<float, float>(processed_quantized_weight_ptr, unprocessed_quantized_weight_ptr,
            get_ptr<float>(scales), get_ptr<float const>(weight), {num_experts, num_rows, num_cols}, ft_quant_type,
            force_interleave);
    }
    else if (_st == at::ScalarType::Half)
    {
        symmetric_quantize<half, half>(processed_quantized_weight_ptr, unprocessed_quantized_weight_ptr,
            get_ptr<half>(scales), get_ptr<half const>(weight), {num_experts, num_rows, num_cols}, ft_quant_type,
            force_interleave);
    }
#ifdef ENABLE_BF16
    else if (_st == at::ScalarType::BFloat16)
    {
        symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(processed_quantized_weight_ptr,
            unprocessed_quantized_weight_ptr, get_ptr<__nv_bfloat16>(scales), get_ptr<__nv_bfloat16 const>(weight),
            {num_experts, num_rows, num_cols}, ft_quant_type, force_interleave);
    }
#endif
    else
    {
        // Reached for BF16 input when the build does not define ENABLE_BF16.
        TORCH_CHECK(false, "Invalid datatype. Weight must be FP32, FP16 or BF16 (BF16 requires ENABLE_BF16)");
    }

    if (return_unprocessed_quantized_tensor)
    {
        return std::vector<Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
    }
    return std::vector<Tensor>{processed_quantized_weight, scales};
}
// Public operator entry point: quantizes `weight` through symmetric_quantize_helper
// and returns {processed_quantized_weight, scales}.
std::vector<Tensor> symmetric_quantize_last_axis_of_batched_matrix(Tensor weight, torch::ScalarType quant_type)
{
    constexpr bool kAlsoReturnUnprocessed = false;
    return symmetric_quantize_helper(weight, quant_type, kAlsoReturnUnprocessed);
}
// Test-oriented variant of symmetric_quantize_last_axis_of_batched_matrix.
// Returns three tensors instead of two:
//   (unprocessed_quantized_weights, preprocessed_quantized_weights, scales)
// so that the unprocessed weights can be passed to torch functions.
std::vector<Tensor> _symmetric_quantize_last_axis_of_batched_matrix(Tensor weight, torch::ScalarType quant_type)
{
    constexpr bool kAlsoReturnUnprocessed = true;
    return symmetric_quantize_helper(weight, quant_type, kAlsoReturnUnprocessed);
}
// Applies the kernel-side bias-and-interleave transform (W4_A16) to a copy of a packed int4 tensor.
//
// weight: contiguous, non-empty CPU int8 tensor; every byte is counted as two int4 elements.
// Returns a detached clone of `weight` that was transformed in place; the input is left untouched.
Tensor add_bias_and_interleave_int4s(Tensor weight)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dtype() == torch::kInt8, "Weight must be a packed int8 tensor");

    Tensor transformed = weight.clone().detach();
    size_t const packed_byte_count = transformed.numel();
    // Two int4 values live in each byte, so the element count is twice the byte count.
    size_t const int4_count = 2 * packed_byte_count;

    add_bias_and_interleave_quantized_tensor_inplace(get_ptr<int8_t>(transformed), int4_count, QuantType::W4_A16);
    return transformed;
}
// Applies the kernel-side bias-and-interleave transform (W8_A16) to a copy of an int8 tensor.
//
// weight: contiguous, non-empty CPU int8 tensor; one byte is one element.
// Returns a detached clone of `weight` that was transformed in place; the input is left untouched.
Tensor add_bias_and_interleave_int8s(Tensor weight)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dtype() == torch::kInt8, "Weight must be an int8 tensor");

    Tensor transformed = weight.clone().detach();
    size_t const int8_count = transformed.numel();

    add_bias_and_interleave_quantized_tensor_inplace(get_ptr<int8_t>(transformed), int8_count, QuantType::W8_A16);
    return transformed;
}
// Expands a packed int4 tensor into one int8 value per element.
//
// weight: contiguous, non-empty CPU int8 tensor with at least one dimension.
//         Each byte holds two signed 4-bit values: the low nibble is the first
//         element, the high nibble the second.
// Returns a new CPU int8 tensor whose last dimension is twice as large and whose
// values are the sign-extended nibbles (range [-8, 7]).
// Raises via TORCH_CHECK on invalid input.
Tensor unpack_int4_packed_tensor_to_int8(Tensor weight)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dtype() == torch::kInt8, "Weight must be a packed int8 tensor");
    // A 0-dim tensor has one element but no last dimension to double.
    TORCH_CHECK(weight.dim() >= 1, "Weight must have at least one dimension");

    std::vector<int64_t> int8_tensor_size = weight.sizes().vec();
    int8_tensor_size.back() *= 2;

    Tensor unpacked_weight
        = torch::zeros(int8_tensor_size, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));

    int8_t* packed_ptr = get_ptr<int8_t>(weight);
    int8_t* unpacked_ptr = get_ptr<int8_t>(unpacked_weight);

    // Sign-extends a 4-bit two's-complement value without shifting negative numbers:
    // flipping the sign bit and subtracting 8 maps 0..7 -> 0..7 and 8..15 -> -8..-1.
    auto const sign_extend_nibble
        = [](uint8_t nibble) -> int8_t { return static_cast<int8_t>((nibble ^ 0x8) - 0x8); };

    int64_t const num_packed = weight.numel();
    for (int64_t packed_idx = 0; packed_idx < num_packed; ++packed_idx)
    {
        uint8_t const packed_bits = static_cast<uint8_t>(packed_ptr[packed_idx]);
        unpacked_ptr[2 * packed_idx + 0] = sign_extend_nibble(static_cast<uint8_t>(packed_bits & 0x0F));
        unpacked_ptr[2 * packed_idx + 1] = sign_extend_nibble(static_cast<uint8_t>(packed_bits >> 4));
    }
    return unpacked_weight;
}
// Packs an int8 tensor whose values all fit in int4 into two-elements-per-byte form.
//
// weight: contiguous, non-empty CPU int8 tensor with at least one dimension and
//         every value in [-8, 7].
// Returns a new CPU int8 tensor whose last dimension is ceil(last / 2). Within each
// row, element 2j goes to the low nibble and element 2j+1 to the high nibble of
// byte j. When the last dimension is odd, the final high nibble of each row is
// padded with 0.
// Raises via TORCH_CHECK on invalid input or on a value outside the int4 range.
Tensor pack_int8_tensor_to_packed_int4(Tensor weight)
{
    CHECK_CPU(weight);
    CHECK_CONTIGUOUS(weight);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dtype() == torch::kInt8, "Weight must be a int8 tensor");
    // A 0-dim tensor has one element but no last dimension to halve.
    TORCH_CHECK(weight.dim() >= 1, "Weight must have at least one dimension");

    // numel() != 0 guarantees unpacked_cols > 0, so the division below is safe.
    int64_t const unpacked_cols = weight.size(-1);
    int64_t const packed_cols = (unpacked_cols + 1) / 2;
    int64_t const num_rows = weight.numel() / unpacked_cols;

    std::vector<int64_t> packed_tensor_size = weight.sizes().vec();
    packed_tensor_size.back() = packed_cols;

    Tensor packed_weight
        = torch::zeros(packed_tensor_size, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));

    int8_t* unpacked_ptr = get_ptr<int8_t>(weight);
    int8_t* packed_ptr = get_ptr<int8_t>(packed_weight);

    // Pack row by row so that an odd last dimension neither reads past the end of
    // the input nor pulls the first element of the next row into this row's last byte.
    for (int64_t row = 0; row < num_rows; ++row)
    {
        int8_t const* src_row = unpacked_ptr + row * unpacked_cols;
        int8_t* dst_row = packed_ptr + row * packed_cols;
        for (int64_t packed_col = 0; packed_col < packed_cols; ++packed_col)
        {
            int64_t const low_col = 2 * packed_col;
            int64_t const high_col = low_col + 1;
            int8_t const elt_0 = src_row[low_col];
            int8_t const elt_1 = high_col < unpacked_cols ? src_row[high_col] : int8_t{0};
            TORCH_CHECK(elt_0 >= -8 && elt_0 <= 7, "Value in unpacked tensor not in int4 range");
            TORCH_CHECK(elt_1 >= -8 && elt_1 <= 7, "Value in unpacked tensor not in int4 range");

            // Combine the nibbles in unsigned arithmetic; shifting a negative signed
            // value left is undefined before C++20.
            uint8_t const low_nibble = static_cast<uint8_t>(static_cast<uint8_t>(elt_0) & 0x0F);
            uint8_t const high_nibble = static_cast<uint8_t>(static_cast<uint8_t>(elt_1) & 0x0F);
            dst_row[packed_col] = static_cast<int8_t>(static_cast<uint8_t>(low_nibble | (high_nibble << 4)));
        }
    }
    return packed_weight;
}
// Dequantizes an unswizzled MXFP4 weight matrix to FP32 on the CPU.
//
// weight:     contiguous, non-empty CPU uint8 tensor of shape (n, k / 2); each byte
//             holds two 4-bit codes, low nibble first.
// scale:      contiguous CPU uint8 tensor of shape (n, k / group_size); each byte is
//             reinterpreted as __nv_fp8_e8m0.
// group_size: number of consecutive elements along k that share one scale; must be
//             positive and satisfy k == scale.size(1) * group_size.
// Returns a new CPU float tensor of shape (n, k) where each element is
// lut[code] * scale[row][col / group_size].
// Raises via TORCH_CHECK on invalid input.
Tensor mxfp4_dequantize_unswizzled(Tensor weight, Tensor scale, int64_t group_size)
{
    CHECK_CPU(weight);
    CHECK_CPU(scale);
    CHECK_CONTIGUOUS(weight);
    CHECK_CONTIGUOUS(scale);
    TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
    TORCH_CHECK(weight.dtype() == torch::kUInt8, "Weight must be a packed uint8 tensor");
    TORCH_CHECK(scale.dtype() == torch::kUInt8, "Scale must be a uint8 tensor");
    // The output is allocated as (n, k); any other rank would make the loops below
    // disagree with the allocation.
    TORCH_CHECK(weight.dim() == 2, "Weight must be a 2-D tensor of shape (n, k / 2)");
    TORCH_CHECK(scale.dim() == 2, "Scale must be a 2-D tensor of shape (n, k / group_size)");
    TORCH_CHECK(group_size > 0, "group_size must be positive");
    TORCH_CHECK(weight.size(0) == scale.size(0), "Weight and scale must have the same number of rows");
    TORCH_CHECK(weight.size(1) * 2 == scale.size(1) * group_size,
        "Weight columns (k) must equal scale columns times group_size");

    uint8_t* weight_packed_ptr = get_ptr<uint8_t>(weight);
    __nv_fp8_e8m0* scale_ptr = reinterpret_cast<__nv_fp8_e8m0*>(get_ptr<uint8_t>(scale));

    // 64-bit extents so that large tensors do not overflow the index arithmetic.
    int64_t const n = weight.size(0);
    int64_t const packed_cols = weight.size(1);
    int64_t const k = packed_cols * 2;
    int64_t const scale_cols = scale.size(1);

    Tensor dequant_weight = torch::empty({n, k}, torch::dtype(torch::kFloat).device(torch::kCPU).requires_grad(false));
    float* dequant_weight_ptr = get_ptr<float>(dequant_weight);

    // Value table indexed by the 4-bit code.
    static constexpr float kFp4Lut[16] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, 0.0f, -0.5f, -1.0f, -1.5f,
        -2.0f, -3.0f, -4.0f, -6.0f};

    for (int64_t row = 0; row < n; ++row)
    {
        uint8_t const* weight_row = weight_packed_ptr + row * packed_cols;
        __nv_fp8_e8m0* scale_row = scale_ptr + row * scale_cols;
        float* out_row = dequant_weight_ptr + row * k;
        for (int64_t packed_col = 0; packed_col < packed_cols; ++packed_col)
        {
            uint8_t const packed_bits = weight_row[packed_col];
            int64_t const low_col = 2 * packed_col;
            int64_t const high_col = low_col + 1;
            // Look the scale up per element so that both nibbles get the right group
            // even when group_size is odd.
            float const low_scale = static_cast<float>(scale_row[low_col / group_size]);
            float const high_scale = static_cast<float>(scale_row[high_col / group_size]);
            out_row[low_col] = kFp4Lut[packed_bits & 0x0F] * low_scale;
            out_row[high_col] = kFp4Lut[packed_bits >> 4] * high_scale;
        }
    }
    return dequant_weight;
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
// Operator registrations. Each file-local static below is initialized with a
// torch::RegisterOperators object that binds one torch_ext function to a schema
// name in the "trtllm::" operator namespace.
//
// Utility methods that may be useful for preprocessing weights in torch.
static auto symmetric_quantize_last_axis_of_batched_matrix
= torch::RegisterOperators("trtllm::symmetric_quantize_last_axis_of_batched_matrix",
&tensorrt_llm::torch_ext::symmetric_quantize_last_axis_of_batched_matrix);
static auto preprocess_weights_for_mixed_gemm = torch::RegisterOperators(
"trtllm::preprocess_weights_for_mixed_gemm", &tensorrt_llm::torch_ext::preprocess_weights_for_mixed_gemm);
static auto unpack_int4_packed_tensor_to_int8 = torch::RegisterOperators(
"trtllm::unpack_int4_packed_tensor_to_int8", &tensorrt_llm::torch_ext::unpack_int4_packed_tensor_to_int8);
static auto pack_int8_tensor_to_packed_int4 = torch::RegisterOperators(
"trtllm::pack_int8_tensor_to_packed_int4", &tensorrt_llm::torch_ext::pack_int8_tensor_to_packed_int4);
// Utility methods exposed purely for unit tests in torch.
// Their operator names carry a leading underscore, except mxfp4_dequantize_unswizzled.
static auto _symmetric_quantize_last_axis_of_batched_matrix
= torch::RegisterOperators("trtllm::_symmetric_quantize_last_axis_of_batched_matrix",
&tensorrt_llm::torch_ext::_symmetric_quantize_last_axis_of_batched_matrix);
static auto add_bias_and_interleave_int4s = torch::RegisterOperators(
"trtllm::_add_bias_and_interleave_int4s", &tensorrt_llm::torch_ext::add_bias_and_interleave_int4s);
static auto add_bias_and_interleave_int8s = torch::RegisterOperators(
"trtllm::_add_bias_and_interleave_int8s", &tensorrt_llm::torch_ext::add_bias_and_interleave_int8s);
static auto permute_B_rows_for_mixed_gemm = torch::RegisterOperators(
"trtllm::_permute_B_rows_for_mixed_gemm", &tensorrt_llm::torch_ext::permute_B_rows_for_mixed_gemm);
static auto subbyte_transpose
= torch::RegisterOperators("trtllm::_subbyte_transpose", &tensorrt_llm::torch_ext::subbyte_transpose);
static auto mxfp4_dequantize_unswizzled = torch::RegisterOperators(
"trtllm::mxfp4_dequantize_unswizzled", &tensorrt_llm::torch_ext::mxfp4_dequantize_unswizzled);