feat: W4A16 GEMM (#4232)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
danielafrimi 2025-07-01 10:36:05 +03:00 committed by GitHub
parent 19c56f0374
commit 7a617ad1fe
9 changed files with 721 additions and 20 deletions

View File

@ -533,7 +533,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
std::vector<CutlassGemmConfig> candidate_configs;
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
int const min_stages = int8_configs_only ? 3 : 2;
int const min_stages = (sm == 89) ? 3 : int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (auto const& tile_config : tiles)
{

View File

@ -84,7 +84,8 @@ add_library(
userbuffersTensor.cpp
weightOnlyQuantOp.cpp
mtpOp.cpp
loraOp.cpp)
loraOp.cpp
finegrained_mixed_dtype_gemm_thop.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
${Python3_LIBRARIES} ${SHARED_TARGET})

View File

@ -0,0 +1,225 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "finegrained_mixed_dtype_gemm_thop.h"
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if defined(ENABLE_FP8) && defined(TRTLLM_CUDA_FP8_AVAILABLE)
#include <cuda_fp8.h>
#endif
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <map>
#include <mutex>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>
namespace torch_ext
{
W4A16GemmRunner::W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode)
: mActivationDtype(activationDtype)
{
if (quant_mode == 0)
{
if (activationDtype == at::ScalarType::Half)
{
mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
}
else if (activationDtype == at::ScalarType::BFloat16)
{
mGemmRunner = std::make_shared<
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>();
}
}
else if (quant_mode == 1)
{
if (activationDtype == at::ScalarType::Half)
{
mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
}
else if (activationDtype == at::ScalarType::BFloat16)
{
mGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, __nv_bfloat16,
__nv_bfloat16, __nv_bfloat16>>();
}
}
else
{
TORCH_CHECK(false, "Unsupported quant mode for W4A16GemmRunner: ", quant_mode);
}
TORCH_CHECK(mGemmRunner, "Failed to create W4A16 GEMM runner for activation type ", c10::toString(activationDtype));
mConfigs = mGemmRunner->getConfigs(); // Get configs via the interface
TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for W4A16 GEMM with activation type ",
c10::toString(activationDtype));
}
at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias, std::optional<at::Tensor> zeros) const
{
TORCH_CHECK(A.is_cuda() && B_packed.is_cuda() && scales.is_cuda(), "All input tensors must be on CUDA");
TORCH_CHECK(A.scalar_type() == mActivationDtype, "Activation tensor A's dtype ", c10::toString(A.scalar_type()),
" does not match runner's expected dtype ", c10::toString(mActivationDtype));
TORCH_CHECK(B_packed.scalar_type() == torch::kQUInt4x2 || B_packed.scalar_type() == torch::kInt8
|| B_packed.scalar_type() == torch::kUInt8,
"B_packed must be quint4x2, int8, or uint8 (view of quantized data)");
TORCH_CHECK(A.is_contiguous() && B_packed.is_contiguous() && scales.is_contiguous(),
"All input tensors (A, B_packed, scales) must be contiguous");
void const* zeros_ptr = nullptr;
if (zeros.has_value())
{
TORCH_CHECK(zeros.value().is_cuda(), "Zeros tensor must be on CUDA");
TORCH_CHECK(zeros.value().scalar_type() == torch::kFloat16 || zeros.value().scalar_type() == torch::kBFloat16,
"Zeros must be FP16 or BF16");
TORCH_CHECK(zeros.value().is_contiguous(), "Zeros tensor must be contiguous");
zeros_ptr = zeros.value().data_ptr();
}
void const* bias_ptr = nullptr;
if (bias.has_value())
{
TORCH_CHECK(bias.value().scalar_type() == torch::kFloat16 || bias.value().scalar_type() == torch::kBFloat16,
"Bias must be FP16 or BF16");
TORCH_CHECK(bias.value().is_contiguous(), "Bias tensor must be contiguous");
bias_ptr = bias.value().data_ptr();
}
int M = 0, K_act = 0;
// Determine M (flattened batch) and K_act from A's dimensions
if (A.dim() == 2)
{
M = A.size(0);
K_act = A.size(1);
}
else
{ // A.dim() >= 3
M = A.size(0);
for (int i = 1; i < A.dim() - 1; ++i)
M *= A.size(i);
K_act = A.size(A.dim() - 1);
}
// B_packed is expected to be [K, N / 2]: each uint8 element packs two int4 values,
// so N_orig = 2 * B_packed.size(1) and K must match A's last dimension.
int K_weights = B_packed.size(0);
int N_packed_int4 = B_packed.size(1); // This is number of uint8_t elements, each holding two int4
int N_orig = N_packed_int4 * 2; // N_orig is the original N dimension
TORCH_CHECK(K_act == K_weights, "K dimension mismatch: A.shape[-1]=", K_act, " vs B_packed.shape[0]=", K_weights);
int K = K_act;
int group_size = static_cast<int>(group_size_long);
std::vector<int64_t> output_shape_vec;
if (A.dim() == 2)
{
output_shape_vec = {static_cast<int64_t>(M), static_cast<int64_t>(N_orig)};
}
else
{
output_shape_vec.reserve(A.dim());
for (int i = 0; i < A.dim() - 1; ++i)
output_shape_vec.push_back(A.size(i));
output_shape_vec.push_back(N_orig);
}
// Set output dtype based on activation dtype
torch::ScalarType output_dtype;
if (mActivationDtype == at::ScalarType::Half)
{
output_dtype = torch::kFloat16;
}
else if (mActivationDtype == at::ScalarType::BFloat16)
{
output_dtype = torch::kBFloat16;
}
else
{
TORCH_CHECK(false, "Unsupported activation type for output dtype determination");
}
torch::Tensor C_tensor = torch::empty(output_shape_vec, A.options().dtype(output_dtype));
void const* A_ptr = A.data_ptr();
TORCH_CHECK(B_packed.is_contiguous(), "B_packed tensor must be contiguous");
void const* B_ptr = B_packed.data_ptr();
void const* scales_ptr = scales.data_ptr();
void* C_ptr = C_tensor.data_ptr();
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config_to_use;
if (configIdx >= 0 && configIdx < getNumConfigs())
{
gemm_config_to_use = mConfigs.at(configIdx);
}
else
{
gemm_config_to_use = mConfigs.at(0);
}
size_t workspace_bytes = mGemmRunner->getWorkspaceSize(M, N_orig, K);
torch::Tensor workspace_tensor = torch::empty(
{static_cast<int64_t>(workspace_bytes)}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
char* workspace_ptr = nullptr;
if (workspace_bytes > 0)
{
workspace_ptr = reinterpret_cast<char*>(workspace_tensor.data_ptr());
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index());
mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr,
1.0f, // alpha
C_ptr, M, N_orig, K, group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream);
return C_tensor;
}
int64_t W4A16GemmRunner::getNumConfigs() const
{
TORCH_CHECK(mGemmRunner, "W4A16GemmRunner not initialized properly.");
return static_cast<int64_t>(mConfigs.size());
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.class_<torch_ext::W4A16GemmRunner>("W4A16GemmRunner")
.def(torch::init<at::ScalarType, int64_t>())
.def("run_gemm", &torch_ext::W4A16GemmRunner::runGemm)
.def("get_num_configs", &torch_ext::W4A16GemmRunner::getNumConfigs);
}
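
For reference, the registered custom class can be driven directly from Python. A minimal sketch, assuming the tensorrt_llm extension is imported (which registers the trtllm ops and classes, as in the unit tests below) and the packed weight has already been prepared with preprocess_weights_for_mixed_gemm, as the linear-layer loading code in this commit does; shapes follow the checks in runGemm above:

import torch
import tensorrt_llm  # assumption: importing tensorrt_llm registers the trtllm extension, as in the tests below
from tensorrt_llm.quantization.functional import preprocess_weights_for_mixed_gemm

M, N, K, group_size = 16, 1024, 256, 64
runner = torch.classes.trtllm.W4A16GemmRunner(torch.float16, 0)  # quant_mode 0 -> FINEGRAINED_SCALE_ONLY

A = torch.randn(M, K, dtype=torch.float16, device="cuda")
# One int8 element of the packed weight holds two int4 values, so the packed layout is [K, N // 2].
packed = torch.randint(-128, 128, (K, N // 2), dtype=torch.int8)
B = preprocess_weights_for_mixed_gemm(packed, torch.quint4x2, torch.float16).cuda().contiguous()
scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")

# configIdx = -1 falls back to the first CUTLASS config; bias and zeros are optional.
out = runner.run_gemm(A, B, scales, group_size, -1, None, None)  # shape [M, N]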

View File

@ -0,0 +1,44 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include <torch/extension.h>
namespace torch_ext
{
class W4A16GemmRunner : public torch::CustomClassHolder
{
public:
explicit W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode = 0);
at::Tensor runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
int64_t group_size_long, int64_t configIdx = -1, std::optional<at::Tensor> bias = std::nullopt,
std::optional<at::Tensor> zeros = std::nullopt) const;
int64_t getNumConfigs() const;
private:
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface> mGemmRunner;
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mConfigs;
at::ScalarType mActivationDtype;
};
} // namespace torch_ext

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
import torch
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._utils import get_sm_version
from ..attention_backend.interface import AttentionInputType
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@ -579,6 +580,84 @@ def _(
dtype=output_dtype)
class W4A16GemmRunner(TunableRunner):
_runner_dict = dict()
MAX_SUPPORTED_SM_VERSION = 90
def __init__(self, activation_dtype: torch.dtype, quant_mode: int):
instance_key = (activation_dtype, quant_mode)
if instance_key not in W4A16GemmRunner._runner_dict:
W4A16GemmRunner._runner_dict[
instance_key] = torch.classes.trtllm.W4A16GemmRunner(
activation_dtype, quant_mode)
self._w4a16_gemm_runner = W4A16GemmRunner._runner_dict[instance_key]
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return list(range(self._w4a16_gemm_runner.get_num_configs()))
def forward(self,
inputs: List[torch.Tensor],
tactic: int = -1,
do_preparation: bool = False,
**kwargs) -> torch.Tensor:
if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
raise ValueError(
f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
)
activation, weights_packed, scales = inputs
return self._w4a16_gemm_runner.run_gemm(
activation,
weights_packed,
scales,
kwargs["group_size"],
tactic,
kwargs["bias"],
kwargs["zeros"],
)
@torch.library.custom_op("trtllm::w4a16_gemm", mutates_args=())
def w4a16_gemm(input: torch.Tensor,
weight: torch.Tensor,
scales: torch.Tensor,
group_size: int,
has_zero_point: bool,
bias: Optional[torch.Tensor] = None,
zeros: Optional[torch.Tensor] = None) -> torch.Tensor:
assert not has_zero_point or zeros is not None, "Expected 'zeros' tensor when has_zero_point is True"
tuner = AutoTuner.get()
tuning_config = TuningConfig(dynamic_tensor_specs=(
# For tensor index 0 (input A), tune dimension 0 (M dimension)
DynamicTensorSpec(0, 0, (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32,
16, 8, 4, 2, 1), last_positive_power_of_2), ))
# NOTE: quant_mode == 0 selects FINEGRAINED_SCALE_ONLY (zeros unused); quant_mode == 1 selects FINEGRAINED_SCALE_AND_ZEROS (scales and zero points)
quant_mode = 1 if has_zero_point else 0
if quant_mode == 0:
assert zeros is None, "When quant_mode is 0 (FINEGRAINED_SCALE_ONLY), zeros must be None"
w4a16_gemm_runner = W4A16GemmRunner(input.dtype, quant_mode)
kwargs = {"group_size": group_size, "zeros": zeros, "bias": bias}
_, best_tactic = tuner.choose_one("trtllm::w4a16_gemm::gemm",
[w4a16_gemm_runner], tuning_config,
[input, weight, scales], **kwargs)
return w4a16_gemm_runner(inputs=[input, weight, scales],
tactic=best_tactic,
**kwargs)
@torch.library.custom_op("trtllm::attention", mutates_args=())
def attention(
q: torch.Tensor,

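The hunk above registers trtllm::w4a16_gemm via torch.library.custom_op; a fake (meta) implementation is not visible in this diff. As a hedged sketch only, assuming none is registered elsewhere in the commit, one way to provide shape propagation for tracing would follow the output-shape rule from the C++ runner (leading dims of the input, last dim = 2 * weight.shape[-1]):

import torch

# Illustrative only: a fake (meta) kernel for shape propagation under torch.compile / tracing.
@torch.library.register_fake("trtllm::w4a16_gemm")
def _w4a16_gemm_fake(input, weight, scales, group_size, has_zero_point, bias=None, zeros=None):
    # Each packed element of `weight` holds two int4 values, so N = 2 * weight.shape[-1],
    # matching N_orig = 2 * B_packed.size(1) in the C++ runner.
    n = weight.shape[-1] * 2
    return input.new_empty(*input.shape[:-1], n)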
View File

@ -16,6 +16,9 @@ from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.quantization.functional import \
preprocess_weights_for_mixed_gemm
from tensorrt_llm.quantization.mode import QuantAlgo
from ...models.modeling_utils import QuantConfig
from ..utils import Fp4QuantizedTensor
@ -106,6 +109,14 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]):
weight = load_weight_shard(weights[0]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
if module.has_w4a16_awq:
# NOTE: without this preprocessing, the GEMM output is NaN. To use preprocess_weights_for_mixed_gemm,
# the weight must first be cast to int8.
weight = preprocess_weights_for_mixed_gemm(
weight.T.to(torch.int8).contiguous().cpu(), torch.quint4x2,
torch.float16).cuda().contiguous()
copy_weight(module.weight, weight)
if module.bias is not None:
@ -194,7 +205,7 @@ class LinearMethodBase(ABC):
"""
@abstractmethod
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
"""
Load weights for the VANILLA weight mode.
"""
@ -202,7 +213,7 @@ class LinearMethodBase(ABC):
@abstractmethod
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
"""
Load weights for the FUSED_QKV_LINEAR weight mode.
"""
@ -210,7 +221,7 @@ class LinearMethodBase(ABC):
@abstractmethod
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
"""
Load weights for the FUSED_GATE_UP_LINEAR weight mode.
"""
@ -242,18 +253,18 @@ class UnquantizedLinearMethod(LinearMethodBase):
output = F.linear(input, module.weight, bias)
return output
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
@ -321,7 +332,7 @@ class FP8QDQLinearMethod(LinearMethodBase):
weight_scale.append(w["weight_scale"][...].reshape([]))
return input_scale, weight_scale
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
@ -335,7 +346,7 @@ class FP8QDQLinearMethod(LinearMethodBase):
copy_weight(module.weight_scale, weight_scale[0])
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
@ -358,7 +369,7 @@ class FP8QDQLinearMethod(LinearMethodBase):
copy_weight(module.weight, fused_weight)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
@ -428,7 +439,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
scale_name = "weight_scale"
return scale_name
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
scale_name = self._get_scale_name(weights)
@ -440,7 +451,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
@ -457,7 +468,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
copy_weight(module.weight_scale, fused_fp8_block_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
@ -566,7 +577,7 @@ class NVFP4LinearMethod(LinearMethodBase):
return input_scale, weight_scale, alpha
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
input_scale, weight_scale, alpha = self.load_weight_scales(
@ -588,7 +599,7 @@ class NVFP4LinearMethod(LinearMethodBase):
copy_weight(module.alpha, alpha)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
@ -609,7 +620,7 @@ class NVFP4LinearMethod(LinearMethodBase):
copy_weight(module.weight, fused_weight)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
@ -696,7 +707,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
return weight_scale
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
weight_scale = self.load_weight_scales(weights,
@ -711,7 +722,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
copy_weight(module.weight_scale, weight_scale)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
@ -727,7 +738,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
copy_weight(module.weight_scale, weight_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
@ -744,6 +755,141 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
copy_weight(module.weight_scale, weight_scale)
class W4A16_AWQ_LinearMethod(LinearMethodBase):
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool,
dtype: torch.dtype) -> None:
# Quantized weights
module.weight = Parameter(torch.empty(
(in_features, out_features // 2),
dtype=torch.int8,
),
requires_grad=False)
group_size = module.quant_config.group_size
if in_features % group_size != 0:
raise ValueError(
f"in_features ({self.in_features}) must be divisible by group_size ({group_size}) "
f"for INT4 per-group quantization scale dimensions.")
module.weight_scale = Parameter(torch.empty(
(out_features, in_features // group_size), dtype=dtype),
requires_grad=False)
# NOTE: not every linear layer has this tensor. pre_quant_scale is computed as an average and merged into the
# LayerNorm for the QKV and Gate/Up projection layers when possible, so it is only present for o_proj and down_proj.
module.pre_quant_scale = None
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
if module.pre_quant_scale is not None:
pre_quant_scale = module.pre_quant_scale.repeat(input.shape[0], 1)
input = torch.mul(input, pre_quant_scale)
bias = bias.contiguous() if bias is not None else None
output = torch.ops.trtllm.w4a16_gemm(input.to(
module.dtype).contiguous(),
module.weight,
module.weight_scale.T.contiguous(),
module.quant_config.group_size,
module.quant_config.has_zero_point,
bias,
zeros=None)
return output
def load_weight_scales(
self,
weights: List[Dict],
tp_size: int = 1,
tp_rank: int = 0,
tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]:
device = torch.device("cuda")
q_weight_scale = load_weight_shard(weights[0]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
k_weight_scale = load_weight_shard(weights[1]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
v_weight_scale = load_weight_shard(weights[2]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale]
return weight_scales
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
device = torch.device('cuda')
pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device)
module.pre_quant_scale = Parameter(
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=device)
weight_scale = load_weight_shard(weights[0]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device)
copy_weight(module.pre_quant_scale, pre_quant_scale)
copy_weight(module.weight_scale, weight_scale)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
torch.float16).cuda().contiguous()
copy_weight(module.weight, fused_weight)
weight_scales = self.load_weight_scales(weights)
# Create concatenated weight scale tensor
cat_weight_scale = torch.cat(weight_scales, dim=0)
copy_weight(module.weight_scale, cat_weight_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
device = torch.device('cuda')
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
torch.float16).cuda().contiguous()
copy_weight(module.weight, fused_weight)
left_scale = load_weight_shard(weights[0]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device).contiguous()
right_scale = load_weight_shard(weights[1]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device).contiguous()
fused_scale = torch.cat([left_scale, right_scale], dim=0)
copy_weight(module.weight_scale, fused_scale)
def get_quant_method(quant_config: Optional[QuantConfig] = None):
if quant_config is None or not quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
@ -756,6 +902,9 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
return NVFP4LinearMethod()
if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
return W4A8MXFP4FP8LinearMethod()
if quant_config.layer_quant_mode.is_int4_weight_only_per_group(
) and quant_config.quant_algo == QuantAlgo.W4A16_AWQ:
return W4A16_AWQ_LinearMethod()
raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}')
@ -859,6 +1008,12 @@ class Linear(nn.Module):
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
)
@property
def has_w4a16_awq(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
) and self.quant_config.quant_algo == QuantAlgo.W4A16_AWQ
def apply_linear(self,
input,
bias,

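For orientation, a small sketch of the tensor shapes that W4A16_AWQ_LinearMethod allocates in create_weights and hands to trtllm::w4a16_gemm in apply(); the dimension values are illustrative, not from the commit:

import torch

in_features, out_features, group_size = 4096, 11008, 128
dtype = torch.float16

# Allocated by create_weights:
weight = torch.empty(in_features, out_features // 2, dtype=torch.int8)   # packed int4: two values per int8
weight_scale = torch.empty(out_features, in_features // group_size, dtype=dtype)
bias = torch.empty(out_features, dtype=dtype)                            # only when bias=True
pre_quant_scale = None                                                   # set only for o_proj / down_proj

# Shapes seen by the kernel in apply():
x = torch.empty(8, in_features, dtype=dtype)                             # activations [M, K]
scales_for_kernel = weight_scale.T.contiguous()                          # [K // group_size, N]
# output of torch.ops.trtllm.w4a16_gemm(...): [M, out_features]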
View File

@ -0,0 +1,94 @@
import pytest
import torch
from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul
import tensorrt_llm
from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner
from tensorrt_llm._utils import get_sm_version
@pytest.mark.parametrize(
"m, n, k, group_size, activation_dtype, has_pre_quant, has_zero, has_bias",
[
(3, 1024, 64, 64, torch.bfloat16, True, False, True),
(128, 1024, 256, 64, torch.bfloat16, True, False, True),
(192, 2048, 384, 64, torch.bfloat16, True, False, True),
(256, 2048, 1024, 64, torch.bfloat16, True, False, True),
(4, 1024, 128, 128, torch.bfloat16, True, False, True),
(64, 1024, 256, 128, torch.bfloat16, True, False, True),
(384, 2048, 384, 128, torch.bfloat16, True, False, True),
(512, 2048, 1024, 128, torch.bfloat16, True, False, True),
(4, 1024, 128, 128, torch.bfloat16, True, True, True),
(64, 1024, 256, 128, torch.bfloat16, True, True, True),
(384, 2048, 384, 128, torch.bfloat16, True, True, True),
(512, 2048, 1024, 128, torch.bfloat16, True, True, False),
(3, 1024, 64, 64, torch.float16, True, False, True),
(128, 1024, 256, 64, torch.float16, True, False, True),
(192, 2048, 384, 64, torch.float16, True, False, True),
(256, 2048, 1024, 64, torch.float16, True, False, True),
(4, 1024, 128, 128, torch.float16, True, False, True),
(64, 1024, 256, 128, torch.float16, True, False, True),
(384, 2048, 384, 128, torch.float16, True, False, True),
(512, 2048, 1024, 128, torch.float16, True, False, True),
(4, 1024, 128, 128, torch.float16, True, True, True),
(64, 1024, 256, 128, torch.float16, True, True, True),
(384, 2048, 384, 128, torch.float16, True, True, True),
(512, 2048, 1024, 128, torch.float16, True, True, False),
])
def test_matmul_activation_int4_input(m, n, k, group_size, activation_dtype,
has_pre_quant, has_zero, has_bias):
torch.manual_seed(0)
device = "cuda"
if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION:
pytest.skip(f"W4A16 not supported for SM version {get_sm_version()}")
total_groups = (k + group_size - 1) // group_size
activation = torch.randn(m, k, dtype=activation_dtype, device=device)
scale = torch.rand(total_groups, n, dtype=activation_dtype, device=device)
zero = torch.randn(total_groups, n, dtype=activation_dtype,
device=device) if has_zero else None
pre_quant_scale = torch.rand(1, k, dtype=activation_dtype, device=device)
bias = torch.randn(1, n, dtype=activation_dtype,
device=device) if has_bias else None
num_weights_in_32_bits = 8 # for torch.quint4x2
unprocessed_int_weight = torch.randint(-2**31,
2**31,
(k, n // num_weights_in_32_bits),
dtype=torch.int32,
device=device)
unprocessed_weight = unprocessed_int_weight.view(torch.int8)
# Ref quantized weights
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
ref_q_weight = unpacker(unprocessed_weight.cpu()).contiguous().cuda()
cuda_q_weight = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm(
unprocessed_weight.cpu(), torch.quint4x2,
activation_dtype).cuda().contiguous()
scale_ref = scale.repeat_interleave(group_size, dim=0)[:k, :]
ref_th_weight = ref_q_weight.to(activation_dtype) * scale_ref
if has_zero:
zero_ref = zero.repeat_interleave(group_size, dim=0)[:k, :]
ref_th_weight += zero_ref
if has_pre_quant:
pre_quant_scale = pre_quant_scale.repeat(m, 1)
activation = torch.mul(activation, pre_quant_scale)
output = torch.ops.trtllm.w4a16_gemm(
activation.contiguous(),
cuda_q_weight,
scale.contiguous(),
group_size,
has_zero,
bias.contiguous() if has_bias else None,
zeros=zero)
ref = woq_groupwise_gt_matmul(activation,
ref_th_weight.to(activation_dtype), bias)
woq_assert_near_eq(ref, output, 2)

View File

@ -0,0 +1,83 @@
import pytest
import torch
import tensorrt_llm.quantization.functional
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
@pytest.mark.parametrize("weights_dtype", [torch.uint8])
@pytest.mark.parametrize(
"dtype",
[torch.float16],
)
def test_w4a16_linear(dtype, weights_dtype, has_zero=False):
if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION:
pytest.skip(
f"W4A116 is not supported in this SM version {get_sm_version()}")
SEQ_LEN = 10
HIDDEN_SIZE = 128
GROUP_SIZE = 128
torch.manual_seed(0)
total_groups = (HIDDEN_SIZE + GROUP_SIZE - 1) // GROUP_SIZE
x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
w = torch.randint(0,
2**32 - 1, (HIDDEN_SIZE, HIDDEN_SIZE // 8),
dtype=torch.uint32,
device=x.device)
w = w.view(weights_dtype)
pre_quant_scale = torch.rand(HIDDEN_SIZE, dtype=dtype).cuda()
weight_scale = torch.rand(total_groups, HIDDEN_SIZE,
dtype=torch.float32).cuda()
bias = torch.randn(HIDDEN_SIZE, dtype=dtype).cuda().contiguous()
qc = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ,
group_size=GROUP_SIZE,
has_zero_point=has_zero)
linear_w4a16 = Linear(in_features=HIDDEN_SIZE,
out_features=HIDDEN_SIZE,
bias=True,
dtype=dtype,
quant_config=qc)
linear_w4a16.load_weights([{
'pre_quant_scale': pre_quant_scale,
'weight': w.T,
'weight_scale': weight_scale.T,
'bias': bias
}])
linear_w4a16 = linear_w4a16.cuda()
w = w.to(torch.int8)
preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
w = preprocessor(w.contiguous().cpu(), torch.quint4x2,
x.dtype).cuda().contiguous()
torch.testing.assert_close(linear_w4a16.weight, w)
with torch.inference_mode(), autotune():
output = linear_w4a16.forward(x)
# ref linear
with torch.inference_mode():
pre_quant_scale = pre_quant_scale.repeat(SEQ_LEN, 1)
x = torch.mul(x, pre_quant_scale)
output_ref = torch.ops.trtllm.w4a16_gemm(x.contiguous(),
w,
weight_scale.type(x.dtype),
GROUP_SIZE,
has_zero,
bias,
zeros=None)
torch.cuda.synchronize()
torch.testing.assert_close(output, output_ref)

View File

@ -377,3 +377,23 @@ def default_dtype(dtype: torch.dtype):
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(cur_default)
def woq_assert_near_eq(ref, act, wTypeId):
# match the scale in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp
if wTypeId == 1:
bits_in_type = 8
else:
bits_in_type = 4
quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))
max_val = torch.max(abs(ref)).item()
atol = (max_val * quant_range_scale) * 1.5 # allow for rounding
torch.testing.assert_close(ref, act, atol=atol, rtol=1e-7)
def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None):
ref = torch.matmul(mat1, ref_torch_weights)
if bias is not None:
ref += bias
return ref
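
For a sense of scale, a short worked example of the tolerance computed above, with illustrative numbers:

import torch

# Int4 path (wTypeId != 1): quant_range_scale = 1 / 8.
ref = torch.tensor([0.5, -4.0, 2.0])
atol = torch.max(abs(ref)).item() * (1.0 / (1 << 3)) * 1.5  # 4.0 * 0.125 * 1.5 = 0.75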