diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
index 902d0f027c..dde7d0f275 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
@@ -533,7 +533,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
     std::vector<CutlassGemmConfig> candidate_configs;
     bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
-    int const min_stages = int8_configs_only ? 3 : 2;
+    int const min_stages = (sm == 89) ? 3 : int8_configs_only ? 3 : 2;
     int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
     for (auto const& tile_config : tiles)
     {
diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt
index bf6b2de375..d2c196c604 100644
--- a/cpp/tensorrt_llm/thop/CMakeLists.txt
+++ b/cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -84,7 +84,8 @@ add_library(
   userbuffersTensor.cpp
   weightOnlyQuantOp.cpp
   mtpOp.cpp
-  loraOp.cpp)
+  loraOp.cpp
+  finegrained_mixed_dtype_gemm_thop.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
                                         ${Python3_LIBRARIES} ${SHARED_TARGET})
diff --git a/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp
new file mode 100644
index 0000000000..9fa36d16b8
--- /dev/null
+++ b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "finegrained_mixed_dtype_gemm_thop.h"
+
+#include "cutlass_extensions/gemm_configs.h"
+#include "cutlass_extensions/weight_only_quant_op.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
+#include "tensorrt_llm/runtime/torchUtils.h"
+
+#include
+#include
+#include
+
+#if defined(ENABLE_FP8) && defined(TRTLLM_CUDA_FP8_AVAILABLE)
+#include
+#endif
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace torch_ext
+{
+
+W4A16GemmRunner::W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode)
+    : mActivationDtype(activationDtype)
+{
+    if (quant_mode == 0)
+    {
+        if (activationDtype == at::ScalarType::Half)
+        {
+            mGemmRunner = std::make_shared<
+                tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t,
+                    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
+        }
+        else if (activationDtype == at::ScalarType::BFloat16)
+        {
+            mGemmRunner = std::make_shared<
+                tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
+                    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>();
+        }
+    }
+    else if (quant_mode == 1)
+    {
+        if (activationDtype == at::ScalarType::Half)
+        {
+            mGemmRunner = std::make_shared<
+                tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t,
+                    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
+        }
+        else if (activationDtype == at::ScalarType::BFloat16)
+        {
+            mGemmRunner = std::make_shared<
+                tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
+                    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>();
+        }
+    }
+    else
+    {
+        TORCH_CHECK(false, "Unsupported quant mode for W4A16GemmRunner: ", quant_mode);
+    }
+
+    TORCH_CHECK(mGemmRunner, "Failed to create W4A16 GEMM runner for activation type ", c10::toString(activationDtype));
+    mConfigs = mGemmRunner->getConfigs(); // Get configs via the interface
+    TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for W4A16 GEMM with activation type ",
+        c10::toString(activationDtype));
+}
+
+at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
+    int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias, std::optional<at::Tensor> zeros) const
+{
+    TORCH_CHECK(A.is_cuda() && B_packed.is_cuda() && scales.is_cuda(), "All input tensors must be on CUDA");
+    TORCH_CHECK(A.scalar_type() == mActivationDtype, "Activation tensor A's dtype ", c10::toString(A.scalar_type()),
+        " does not match runner's expected dtype ", c10::toString(mActivationDtype));
+    TORCH_CHECK(B_packed.scalar_type() == torch::kQUInt4x2 || B_packed.scalar_type() == torch::kInt8
+            || B_packed.scalar_type() == torch::kUInt8,
+        "B_packed must be quint4x2, int8, or uint8 (view of quantized data)");
+    TORCH_CHECK(A.is_contiguous() && B_packed.is_contiguous() && scales.is_contiguous(),
+        "All input tensors (A, B_packed, scales) must be contiguous");
+
+    void const* zeros_ptr = nullptr;
+    if (zeros.has_value())
+    {
+        TORCH_CHECK(zeros.value().is_cuda(), "Zeros tensor must be on CUDA");
+        TORCH_CHECK(zeros.value().scalar_type() == torch::kFloat16 || zeros.value().scalar_type() == torch::kBFloat16,
+            "Zeros must be FP16 or BF16");
+        TORCH_CHECK(zeros.value().is_contiguous(), "Zeros tensor must be contiguous");
+        zeros_ptr = zeros.value().data_ptr();
+    }
+
+    void const* bias_ptr = nullptr;
+    if (bias.has_value())
+    {
+        TORCH_CHECK(bias.value().scalar_type() == torch::kFloat16 || bias.value().scalar_type() == torch::kBFloat16,
+            "Bias must be FP16 or BF16");
+        TORCH_CHECK(bias.value().is_contiguous(), "Bias tensor must be contiguous");
+        bias_ptr = bias.value().data_ptr();
+    }
+
+    int M = 0, K_act = 0;
+    // Determine M and K_act from the dimensions of A
+    if (A.dim() == 2)
+    {
+        M = A.size(0);
+        K_act = A.size(1);
+    }
+    else
+    { // A.dim() >= 3
+        M = A.size(0);
+        for (int i = 1; i < A.dim() - 1; ++i)
+            M *= A.size(i);
+        K_act = A.size(A.dim() - 1);
+    }
+
+    // B_packed is assumed to be [K_weights, N_packed_int4_pairs];
+    // K_weights must match K_act and the original N dimension is 2 * N_packed_int4_pairs.
+    int K_weights = B_packed.size(0);
+    int N_packed_int4 = B_packed.size(1); // Number of uint8_t elements, each holding two int4 values
+    int N_orig = N_packed_int4 * 2;       // N_orig is the original N dimension
+
+    TORCH_CHECK(K_act == K_weights, "K dimension mismatch: A.shape[-1]=", K_act, " vs B_packed.shape[0]=", K_weights);
+    int K = K_act;
+    int group_size = static_cast<int>(group_size_long);
+
+    std::vector<int64_t> output_shape_vec;
+    if (A.dim() == 2)
+    {
+        output_shape_vec = {static_cast<int64_t>(M), static_cast<int64_t>(N_orig)};
+    }
+    else
+    {
+        output_shape_vec.reserve(A.dim());
+        for (int i = 0; i < A.dim() - 1; ++i)
+            output_shape_vec.push_back(A.size(i));
+        output_shape_vec.push_back(N_orig);
+    }
+
+    // Set output dtype based on activation dtype
+    torch::ScalarType output_dtype;
+    if (mActivationDtype == at::ScalarType::Half)
+    {
+        output_dtype = torch::kFloat16;
+    }
+    else if (mActivationDtype == at::ScalarType::BFloat16)
+    {
+        output_dtype = torch::kBFloat16;
+    }
+    else
+    {
+        TORCH_CHECK(false, "Unsupported activation type for output dtype determination");
+    }
+
+    torch::Tensor C_tensor = torch::empty(output_shape_vec, A.options().dtype(output_dtype));
+
+    void const* A_ptr = A.data_ptr();
+
+    TORCH_CHECK(B_packed.is_contiguous(), "B_packed tensor must be contiguous");
+    void const* B_ptr = B_packed.data_ptr();
+    void const* scales_ptr = scales.data_ptr();
+    void* C_ptr = C_tensor.data_ptr();
+
+    tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config_to_use;
+    if (configIdx >= 0 && configIdx < getNumConfigs())
+    {
+        gemm_config_to_use = mConfigs.at(configIdx);
+    }
+    else
+    {
+        gemm_config_to_use = mConfigs.at(0);
+    }
+
+    size_t workspace_bytes = mGemmRunner->getWorkspaceSize(M, N_orig, K);
+    torch::Tensor workspace_tensor = torch::empty(
+        {static_cast<int64_t>(workspace_bytes)}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
+    char* workspace_ptr = nullptr;
+    if (workspace_bytes > 0)
+    {
+        workspace_ptr = reinterpret_cast<char*>(workspace_tensor.data_ptr());
+    }
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index());
+
+    mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr,
+        1.0f, // alpha
+        C_ptr, M, N_orig, K, group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream);
+
+    return C_tensor;
+}
+
+int64_t W4A16GemmRunner::getNumConfigs() const
+{
+    TORCH_CHECK(mGemmRunner, "W4A16GemmRunner not initialized properly.");
+    return static_cast<int64_t>(mConfigs.size());
+}
+
+} // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.class_<torch_ext::W4A16GemmRunner>("W4A16GemmRunner")
+        .def(torch::init<at::ScalarType, int64_t>())
+        .def("run_gemm", &torch_ext::W4A16GemmRunner::runGemm)
+        .def("get_num_configs", &torch_ext::W4A16GemmRunner::getNumConfigs);
+}
diff --git a/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h
new file mode 100644
index 0000000000..1b2083de5a
--- /dev/null
+++ b/cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cutlass_extensions/gemm_configs.h"
+#include "cutlass_extensions/weight_only_quant_op.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
+#include
+
+namespace torch_ext
+{
+
+class W4A16GemmRunner : public torch::CustomClassHolder
+{
+public:
+    explicit W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode = 0);
+
+    at::Tensor runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
+        int64_t group_size_long, int64_t configIdx = -1, std::optional<at::Tensor> bias = std::nullopt,
+        std::optional<at::Tensor> zeros = std::nullopt) const;
+
+    int64_t getNumConfigs() const;
+
+private:
+    std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface> mGemmRunner;
+    std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mConfigs;
+    at::ScalarType mActivationDtype;
+};
+
+} // namespace torch_ext
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index 6d1e8c05e5..f15f70111b 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
 import torch
 
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
+from tensorrt_llm._utils import get_sm_version
 
 from ..attention_backend.interface import AttentionInputType
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@@ -579,6 +580,84 @@ def _(
         dtype=output_dtype)
 
 
+class W4A16GemmRunner(TunableRunner):
+    _runner_dict = dict()
+    MAX_SUPPORTED_SM_VERSION = 90
+
+    def __init__(self, activation_dtype: torch.dtype, quant_mode: int):
+        instance_key = (activation_dtype, quant_mode)
+        if instance_key not in W4A16GemmRunner._runner_dict:
+            W4A16GemmRunner._runner_dict[
+                instance_key] = torch.classes.trtllm.W4A16GemmRunner(
+                    activation_dtype, quant_mode)
+        self._w4a16_gemm_runner = W4A16GemmRunner._runner_dict[instance_key]
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        return list(range(self._w4a16_gemm_runner.get_num_configs()))
+
+    def forward(self,
+                inputs: List[torch.Tensor],
+                tactic: int = -1,
+                do_preparation: bool = False,
+                **kwargs) -> torch.Tensor:
+
+        if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
+            raise ValueError(
+                f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
+            )
+
+        activation, weights_packed, scales = inputs
+
+        return self._w4a16_gemm_runner.run_gemm(
+            activation,
+            weights_packed,
+            scales,
+            kwargs["group_size"],
+            tactic,
+            kwargs["bias"],
+            kwargs["zeros"],
+        )
+
+
+@torch.library.custom_op("trtllm::w4a16_gemm", mutates_args=())
+def w4a16_gemm(input: torch.Tensor,
+               weight: torch.Tensor,
+               scales: torch.Tensor,
+               group_size: int,
+               has_zero_point: bool,
+               bias: Optional[torch.Tensor] = None,
+               zeros: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    assert not has_zero_point or zeros is not None, "Expected 'zeros' tensor when has_zero_point is True"
+
+    tuner = AutoTuner.get()
+
+    tuning_config = TuningConfig(dynamic_tensor_specs=(
+        # For tensor index 0 (input A), tune dimension 0 (M dimension)
+        DynamicTensorSpec(0, 0, (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32,
+                                 16, 8, 4, 2, 1), last_positive_power_of_2), ))
+
+    # NOTE: quant_mode == 0 means scale-only dequantization (FINEGRAINED_SCALE_ONLY) and zeros is unused; otherwise both scales and zero points are used
+    quant_mode = 1 if has_zero_point else 0
+    if quant_mode == 0:
+        assert zeros is None, "When quant_mode is 0 (FINEGRAINED_SCALE_ONLY), zeros must be None"
+
+    w4a16_gemm_runner = W4A16GemmRunner(input.dtype, quant_mode)
+
+    kwargs = {"group_size": group_size, "zeros": zeros, "bias": bias}
+    _, best_tactic = tuner.choose_one("trtllm::w4a16_gemm::gemm",
+                                      [w4a16_gemm_runner], tuning_config,
+                                      [input, weight, scales], **kwargs)
+
+    return w4a16_gemm_runner(inputs=[input, weight, scales],
+                             tactic=best_tactic,
+                             **kwargs)
+
+
 @torch.library.custom_op("trtllm::attention", mutates_args=())
 def attention(
     q: torch.Tensor,
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 99cbde435f..db34b53384 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -16,6 +16,9 @@ from tensorrt_llm._torch.peft.lora.layer import LoraLayer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy)
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.quantization.functional import \
+    preprocess_weights_for_mixed_gemm
+from tensorrt_llm.quantization.mode import QuantAlgo
 
 from ...models.modeling_utils import QuantConfig
 from ..utils import Fp4QuantizedTensor
@@ -106,6 +109,14 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]):
 
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
                                module.tp_rank, module.tp_mode, device)
+
+    if module.has_w4a16_awq:
+        # NOTE: without this preprocessing at runtime, the GEMM outputs NaNs. In order to use preprocess_weights_for_mixed_gemm
+        # we need to cast the weight to int8 first.
+        weight = preprocess_weights_for_mixed_gemm(
+            weight.T.to(torch.int8).contiguous().cpu(), torch.quint4x2,
+            torch.float16).cuda().contiguous()
+
     copy_weight(module.weight, weight)
 
     if module.bias is not None:
@@ -194,7 +205,7 @@ class LinearMethodBase(ABC):
     """
 
     @abstractmethod
-    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
+    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         """
         Load weights for the VANILLA weight mode.
         """
@@ -202,7 +213,7 @@ class LinearMethodBase(ABC):
 
     @abstractmethod
     def load_weights_fused_qkv_linear(self, module: Linear,
-                                      weights: List[Dict]):
+                                      weights: List[Dict]) -> None:
         """
         Load weights for the FUSED_QKV_LINEAR weight mode.
         """
@@ -210,7 +221,7 @@ class LinearMethodBase(ABC):
 
     @abstractmethod
     def load_weights_fused_gate_up_linear(self, module: Linear,
-                                          weights: List[Dict]):
+                                          weights: List[Dict]) -> None:
         """
         Load weights for the FUSED_GATE_UP_LINEAR weight mode.
""" @@ -242,18 +253,18 @@ class UnquantizedLinearMethod(LinearMethodBase): output = F.linear(input, module.weight, bias) return output - def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: load_weights_vanilla_helper(module, weights) def load_weights_fused_qkv_linear(self, module: Linear, - weights: List[Dict]): + weights: List[Dict]) -> None: q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( module, weights) fused_weight = torch.cat((q_weight, k_weight, v_weight)) copy_weight(module.weight, fused_weight) def load_weights_fused_gate_up_linear(self, module: Linear, - weights: List[Dict]): + weights: List[Dict]) -> None: gate_weight, up_weight = load_weights_fused_gate_up_helper( module, weights) fused_weight = torch.cat((gate_weight, up_weight)) @@ -321,7 +332,7 @@ class FP8QDQLinearMethod(LinearMethodBase): weight_scale.append(w["weight_scale"][...].reshape([])) return input_scale, weight_scale - def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: load_weights_vanilla_helper(module, weights) input_scale, weight_scale = self.load_weight_scales(weights) if len(input_scale) != 0: @@ -335,7 +346,7 @@ class FP8QDQLinearMethod(LinearMethodBase): copy_weight(module.weight_scale, weight_scale[0]) def load_weights_fused_qkv_linear(self, module: Linear, - weights: List[Dict]): + weights: List[Dict]) -> None: q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( module, weights) @@ -358,7 +369,7 @@ class FP8QDQLinearMethod(LinearMethodBase): copy_weight(module.weight, fused_weight) def load_weights_fused_gate_up_linear(self, module: Linear, - weights: List[Dict]): + weights: List[Dict]) -> None: input_scale, weight_scale = self.load_weight_scales(weights) if len(input_scale) != 0: # Static quantization @@ -428,7 +439,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase): scale_name = "weight_scale" return scale_name - def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: load_weights_vanilla_helper(module, weights) scale_name = self._get_scale_name(weights) @@ -440,7 +451,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase): module.inv_input_scale.data = 1.0 / module.input_scale def load_weights_fused_qkv_linear(self, module: Linear, - weights: List[Dict]): + weights: List[Dict]) -> None: q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( module, weights) fused_weight = torch.cat((q_weight, k_weight, v_weight)) @@ -457,7 +468,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase): copy_weight(module.weight_scale, fused_fp8_block_scale) def load_weights_fused_gate_up_linear(self, module: Linear, - weights: List[Dict]): + weights: List[Dict]) -> None: gate_weight, up_weight = load_weights_fused_gate_up_helper( module, weights) fused_weight = torch.cat((gate_weight, up_weight)) @@ -566,7 +577,7 @@ class NVFP4LinearMethod(LinearMethodBase): return input_scale, weight_scale, alpha - def load_weights_vanilla(self, module: Linear, weights: List[Dict]): + def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: load_weights_vanilla_helper(module, weights) input_scale, weight_scale, alpha = self.load_weight_scales( @@ -588,7 +599,7 @@ class NVFP4LinearMethod(LinearMethodBase): copy_weight(module.alpha, alpha) def load_weights_fused_qkv_linear(self, module: Linear, - 
                                      weights: List[Dict]):
+                                      weights: List[Dict]) -> None:
         q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
             module, weights)
 
@@ -609,7 +620,7 @@ class NVFP4LinearMethod(LinearMethodBase):
         copy_weight(module.weight, fused_weight)
 
     def load_weights_fused_gate_up_linear(self, module: Linear,
-                                          weights: List[Dict]):
+                                          weights: List[Dict]) -> None:
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
         fused_weight = torch.cat((gate_weight, up_weight))
@@ -696,7 +707,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
             weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
         return weight_scale
 
-    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
+    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)
 
         weight_scale = self.load_weight_scales(weights,
@@ -711,7 +722,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
         copy_weight(module.weight_scale, weight_scale)
 
     def load_weights_fused_qkv_linear(self, module: Linear,
-                                      weights: List[Dict]):
+                                      weights: List[Dict]) -> None:
         q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
             module, weights)
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
@@ -727,7 +738,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
         copy_weight(module.weight_scale, weight_scale)
 
     def load_weights_fused_gate_up_linear(self, module: Linear,
-                                          weights: List[Dict]):
+                                          weights: List[Dict]) -> None:
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
         fused_weight = torch.cat((gate_weight, up_weight))
@@ -744,6 +755,141 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
         copy_weight(module.weight_scale, weight_scale)
 
 
+class W4A16_AWQ_LinearMethod(LinearMethodBase):
+
+    def create_weights(self, module: Linear, in_features: int,
+                       out_features: int, bias: bool,
+                       dtype: torch.dtype) -> None:
+        # Quantized weights
+        module.weight = Parameter(torch.empty(
+            (in_features, out_features // 2),
+            dtype=torch.int8,
+        ),
+                                  requires_grad=False)
+
+        group_size = module.quant_config.group_size
+        if in_features % group_size != 0:
+            raise ValueError(
+                f"in_features ({in_features}) must be divisible by group_size ({group_size}) "
+                f"for INT4 per-group quantization scale dimensions.")
+
+        module.weight_scale = Parameter(torch.empty(
+            (out_features, in_features // group_size), dtype=dtype),
+                                        requires_grad=False)
+        # NOTE: pre_quant_scale is not present for every Linear - it is computed as an average and merged with the
+        # LayerNorm for QKV and Gate/Up projection layers when possible;
we can see the tensor only for o_proj and down_proj + module.pre_quant_scale = None + + if bias: + module.bias = Parameter(torch.empty((out_features), dtype=dtype), + requires_grad=False) + else: + module.register_parameter("bias", None) + + def apply(self, module: Linear, input: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + if module.pre_quant_scale is not None: + pre_quant_scale = module.pre_quant_scale.repeat(input.shape[0], 1) + input = torch.mul(input, pre_quant_scale) + + bias = bias.contiguous() if bias is not None else None + + output = torch.ops.trtllm.w4a16_gemm(input.to( + module.dtype).contiguous(), + module.weight, + module.weight_scale.T.contiguous(), + module.quant_config.group_size, + module.quant_config.has_zero_point, + bias, + zeros=None) + return output + + def load_weight_scales( + self, + weights: List[Dict], + tp_size: int = 1, + tp_rank: int = 0, + tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]: + device = torch.device("cuda") + q_weight_scale = load_weight_shard(weights[0]['weight_scale'], + tp_size, + tp_rank, + tp_mode, + device=device) + k_weight_scale = load_weight_shard(weights[1]['weight_scale'], + tp_size, + tp_rank, + tp_mode, + device=device) + v_weight_scale = load_weight_shard(weights[2]['weight_scale'], + tp_size, + tp_rank, + tp_mode, + device=device) + weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale] + + return weight_scales + + def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: + load_weights_vanilla_helper(module, weights) + + device = torch.device('cuda') + pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'], + module.tp_size, module.tp_rank, + module.tp_mode, device) + module.pre_quant_scale = Parameter( + torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype), + requires_grad=False).to(device=device) + + weight_scale = load_weight_shard(weights[0]['weight_scale'], + module.tp_size, module.tp_rank, + module.tp_mode, device) + + copy_weight(module.pre_quant_scale, pre_quant_scale) + copy_weight(module.weight_scale, weight_scale) + + def load_weights_fused_qkv_linear(self, module: Linear, + weights: List[Dict]) -> None: + q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( + module, weights) + + fused_weight = torch.cat((q_weight, k_weight, v_weight)) + fused_weight = preprocess_weights_for_mixed_gemm( + fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2, + torch.float16).cuda().contiguous() + + copy_weight(module.weight, fused_weight) + + weight_scales = self.load_weight_scales(weights) + + # Create concatenated weight scale tensor + cat_weight_scale = torch.cat(weight_scales, dim=0) + copy_weight(module.weight_scale, cat_weight_scale) + + def load_weights_fused_gate_up_linear(self, module: Linear, + weights: List[Dict]) -> None: + device = torch.device('cuda') + gate_weight, up_weight = load_weights_fused_gate_up_helper( + module, weights) + + fused_weight = torch.cat((gate_weight, up_weight)) + fused_weight = preprocess_weights_for_mixed_gemm( + fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2, + torch.float16).cuda().contiguous() + + copy_weight(module.weight, fused_weight) + + left_scale = load_weight_shard(weights[0]['weight_scale'], + module.tp_size, module.tp_rank, + module.tp_mode, device).contiguous() + right_scale = load_weight_shard(weights[1]['weight_scale'], + module.tp_size, module.tp_rank, + module.tp_mode, device).contiguous() + fused_scale = torch.cat([left_scale, right_scale], dim=0) + 
copy_weight(module.weight_scale, fused_scale) + + def get_quant_method(quant_config: Optional[QuantConfig] = None): if quant_config is None or not quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -756,6 +902,9 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None): return NVFP4LinearMethod() if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(): return W4A8MXFP4FP8LinearMethod() + if quant_config.layer_quant_mode.is_int4_weight_only_per_group( + ) and quant_config.quant_algo == QuantAlgo.W4A16_AWQ: + return W4A16_AWQ_LinearMethod() raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}') @@ -859,6 +1008,12 @@ class Linear(nn.Module): return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4( ) + @property + def has_w4a16_awq(self): + assert self._weights_created + return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( + ) and self.quant_config.quant_algo == QuantAlgo.W4A16_AWQ + def apply_linear(self, input, bias, diff --git a/tests/unittest/_torch/thop/test_w4a16_gemm.py b/tests/unittest/_torch/thop/test_w4a16_gemm.py new file mode 100644 index 0000000000..b3a034bd5d --- /dev/null +++ b/tests/unittest/_torch/thop/test_w4a16_gemm.py @@ -0,0 +1,94 @@ +import pytest +import torch +from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul + +import tensorrt_llm +from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner +from tensorrt_llm._utils import get_sm_version + + +@pytest.mark.parametrize( + "m, n, k, group_size, activation_dtype, has_pre_quant, has_zero, has_bias", + [ + (3, 1024, 64, 64, torch.bfloat16, True, False, True), + (128, 1024, 256, 64, torch.bfloat16, True, False, True), + (192, 2048, 384, 64, torch.bfloat16, True, False, True), + (256, 2048, 1024, 64, torch.bfloat16, True, False, True), + (4, 1024, 128, 128, torch.bfloat16, True, False, True), + (64, 1024, 256, 128, torch.bfloat16, True, False, True), + (384, 2048, 384, 128, torch.bfloat16, True, False, True), + (512, 2048, 1024, 128, torch.bfloat16, True, False, True), + (4, 1024, 128, 128, torch.bfloat16, True, True, True), + (64, 1024, 256, 128, torch.bfloat16, True, True, True), + (384, 2048, 384, 128, torch.bfloat16, True, True, True), + (512, 2048, 1024, 128, torch.bfloat16, True, True, False), + (3, 1024, 64, 64, torch.float16, True, False, True), + (128, 1024, 256, 64, torch.float16, True, False, True), + (192, 2048, 384, 64, torch.float16, True, False, True), + (256, 2048, 1024, 64, torch.float16, True, False, True), + (4, 1024, 128, 128, torch.float16, True, False, True), + (64, 1024, 256, 128, torch.float16, True, False, True), + (384, 2048, 384, 128, torch.float16, True, False, True), + (512, 2048, 1024, 128, torch.float16, True, False, True), + (4, 1024, 128, 128, torch.float16, True, True, True), + (64, 1024, 256, 128, torch.float16, True, True, True), + (384, 2048, 384, 128, torch.float16, True, True, True), + (512, 2048, 1024, 128, torch.float16, True, True, False), + ]) +def test_matmul_activation_int4_input(m, n, k, group_size, activation_dtype, + has_pre_quant, has_zero, has_bias): + torch.manual_seed(0) + device = "cuda" + + if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION: + pytest.skip(f"W4A16 not supported for SM version {get_sm_version()}") + + total_groups = (k + group_size - 1) // group_size + activation = torch.randn(m, k, dtype=activation_dtype, device=device) + scale = torch.rand(total_groups, n, dtype=activation_dtype, 
                       device=device)
+    zero = torch.randn(total_groups, n, dtype=activation_dtype,
+                       device=device) if has_zero else None
+    pre_quant_scale = torch.rand(1, k, dtype=activation_dtype, device=device)
+    bias = torch.randn(1, n, dtype=activation_dtype,
+                       device=device) if has_bias else None
+
+    num_weights_in_32_bits = 8  # for torch.quint4x2
+    unprocessed_int_weight = torch.randint(-2**31,
+                                           2**31,
+                                           (k, n // num_weights_in_32_bits),
+                                           dtype=torch.int32,
+                                           device=device)
+    unprocessed_weight = unprocessed_int_weight.view(torch.int8)
+
+    # Ref quantized weights
+    unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
+    ref_q_weight = unpacker(unprocessed_weight.cpu()).contiguous().cuda()
+
+    cuda_q_weight = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm(
+        unprocessed_weight.cpu(), torch.quint4x2,
+        activation_dtype).cuda().contiguous()
+
+    scale_ref = scale.repeat_interleave(group_size, dim=0)[:k, :]
+    ref_th_weight = ref_q_weight.to(activation_dtype) * scale_ref
+
+    if has_zero:
+        zero_ref = zero.repeat_interleave(group_size, dim=0)[:k, :]
+        ref_th_weight += zero_ref
+
+    if has_pre_quant:
+        pre_quant_scale = pre_quant_scale.repeat(m, 1)
+        activation = torch.mul(activation, pre_quant_scale)
+
+    output = torch.ops.trtllm.w4a16_gemm(
+        activation.contiguous(),
+        cuda_q_weight,
+        scale.contiguous(),
+        group_size,
+        has_zero,
+        bias.contiguous() if has_bias else None,
+        zeros=zero)
+
+    ref = woq_groupwise_gt_matmul(activation,
+                                  ref_th_weight.to(activation_dtype), bias)
+
+    woq_assert_near_eq(ref, output, 2)
diff --git a/tests/unittest/_torch/thop/test_w4a16_linear.py b/tests/unittest/_torch/thop/test_w4a16_linear.py
new file mode 100644
index 0000000000..1398acc297
--- /dev/null
+++ b/tests/unittest/_torch/thop/test_w4a16_linear.py
@@ -0,0 +1,83 @@
+import pytest
+import torch
+
+import tensorrt_llm.quantization.functional
+from tensorrt_llm._torch.autotuner import autotune
+from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner
+from tensorrt_llm._torch.modules.linear import Linear
+from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
+
+
+@pytest.mark.parametrize("weights_dtype", [torch.uint8])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float16],
+)
+def test_w4a16_linear(dtype, weights_dtype, has_zero=False):
+
+    if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION:
+        pytest.skip(
+            f"W4A16 is not supported on SM version {get_sm_version()}")
+
+    SEQ_LEN = 10
+    HIDDEN_SIZE = 128
+    GROUP_SIZE = 128
+    torch.manual_seed(0)
+
+    total_groups = (HIDDEN_SIZE + GROUP_SIZE - 1) // GROUP_SIZE
+
+    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
+    w = torch.randint(0,
+                      2**32 - 1, (HIDDEN_SIZE, HIDDEN_SIZE // 8),
+                      dtype=torch.uint32,
+                      device=x.device)
+    w = w.view(weights_dtype)
+
+    pre_quant_scale = torch.rand(HIDDEN_SIZE, dtype=dtype).cuda()
+    weight_scale = torch.rand(total_groups, HIDDEN_SIZE,
+                              dtype=torch.float32).cuda()
+    bias = torch.randn(HIDDEN_SIZE, dtype=dtype).cuda().contiguous()
+
+    qc = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ,
+                     group_size=GROUP_SIZE,
+                     has_zero_point=has_zero)
+    linear_w4a16 = Linear(in_features=HIDDEN_SIZE,
+                          out_features=HIDDEN_SIZE,
+                          bias=True,
+                          dtype=dtype,
+                          quant_config=qc)
+
+    linear_w4a16.load_weights([{
+        'pre_quant_scale': pre_quant_scale,
+        'weight': w.T,
+        'weight_scale': weight_scale.T,
+        'bias': bias
+    }])
+
+    linear_w4a16 = linear_w4a16.cuda()
+
+    w = w.to(torch.int8)
+    preprocessor =
tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm + w = preprocessor(w.contiguous().cpu(), torch.quint4x2, + x.dtype).cuda().contiguous() + + torch.testing.assert_close(linear_w4a16.weight, w) + + with torch.inference_mode(), autotune(): + output = linear_w4a16.forward(x) + + # ref linear + with torch.inference_mode(): + pre_quant_scale = pre_quant_scale.repeat(SEQ_LEN, 1) + x = torch.mul(x, pre_quant_scale) + + output_ref = torch.ops.trtllm.w4a16_gemm(x.contiguous(), + w, + weight_scale.type(x.dtype), + GROUP_SIZE, + has_zero, + bias, + zeros=None) + torch.cuda.synchronize() + torch.testing.assert_close(output, output_ref) diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py index e34370e9eb..72f205dc51 100644 --- a/tests/unittest/utils/util.py +++ b/tests/unittest/utils/util.py @@ -377,3 +377,23 @@ def default_dtype(dtype: torch.dtype): torch.set_default_dtype(dtype) yield torch.set_default_dtype(cur_default) + + +def woq_assert_near_eq(ref, act, wTypeId): + # match the scale in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp + if wTypeId == 1: + bits_in_type = 8 + else: + bits_in_type = 4 + quant_range_scale = 1.0 / float(1 << (bits_in_type - 1)) + + max_val = torch.max(abs(ref)).item() + atol = (max_val * quant_range_scale) * 1.5 # allow for rounding + torch.testing.assert_close(ref, act, atol=atol, rtol=1e-7) + + +def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None): + ref = torch.matmul(mat1, ref_torch_weights) + if bias is not None: + ref += bias + return ref
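
For reference, a minimal sketch of driving the new trtllm::w4a16_gemm op directly in scale-only mode, mirroring test_w4a16_gemm.py above; the shapes, the float16 activation dtype, and the randomly generated packed weights are illustrative assumptions rather than a prescribed usage.

import torch
import tensorrt_llm
import tensorrt_llm._torch.custom_ops.torch_custom_ops  # noqa: F401  (registers trtllm::w4a16_gemm)

m, n, k, group_size = 16, 1024, 256, 128
dtype = torch.float16

activation = torch.randn(m, k, dtype=dtype, device="cuda")
scales = torch.rand(k // group_size, n, dtype=dtype, device="cuda")

# Random int4 weights packed two-per-byte: an int8 view of shape [k, n // 2].
packed = torch.randint(-2**31, 2**31, (k, n // 8), dtype=torch.int32,
                       device="cuda").view(torch.int8)

# Shuffle the packed weights into the interleaved layout the CUTLASS kernel expects.
weight = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm(
    packed.cpu(), torch.quint4x2, dtype).cuda().contiguous()

# has_zero_point=False selects quant_mode 0 (FINEGRAINED_SCALE_ONLY); zeros stays None.
out = torch.ops.trtllm.w4a16_gemm(activation, weight, scales, group_size,
                                  False, None, zeros=None)
assert out.shape == (m, n)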