Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

feat: W4A16 GEMM (#4232)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>

commit 7a617ad1fe
parent 19c56f0374
@@ -533,7 +533,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
    std::vector<CutlassGemmConfig> candidate_configs;

    bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
    int const min_stages = int8_configs_only ? 3 : 2;
    int const min_stages = (sm == 89) ? 3 : int8_configs_only ? 3 : 2;
    int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
    for (auto const& tile_config : tiles)
    {
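The only functional change in this hunk is the minimum pipeline stage count, which is now forced to 3 on SM 89 (Ada). A hedged restatement of the selection logic as a small Python helper, purely for illustration (the function name is ours, not part of the diff):

def min_pipeline_stages(sm: int, int8_configs_only: bool) -> int:
    # Mirrors the updated line above: SM 89 always gets at least 3 stages;
    # otherwise int8-only configs need 3 and everything else needs 2.
    return 3 if sm == 89 else (3 if int8_configs_only else 2)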
@@ -84,7 +84,8 @@ add_library(
  userbuffersTensor.cpp
  weightOnlyQuantOp.cpp
  mtpOp.cpp
  loraOp.cpp)
  loraOp.cpp
  finegrained_mixed_dtype_gemm_thop.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
                      ${Python3_LIBRARIES} ${SHARED_TARGET})
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (new file, 225 lines)
@@ -0,0 +1,225 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "finegrained_mixed_dtype_gemm_thop.h"

#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#if defined(ENABLE_FP8) && defined(TRTLLM_CUDA_FP8_AVAILABLE)
#include <cuda_fp8.h>
#endif

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <map>
#include <mutex>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

namespace torch_ext
{

W4A16GemmRunner::W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode)
    : mActivationDtype(activationDtype)
{
    if (quant_mode == 0)
    {
        if (activationDtype == at::ScalarType::Half)
        {
            mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
                cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
        }
        else if (activationDtype == at::ScalarType::BFloat16)
        {
            mGemmRunner = std::make_shared<
                tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
                    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>();
        }
    }
    else if (quant_mode == 1)
    {
        if (activationDtype == at::ScalarType::Half)
        {
            mGemmRunner = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
                cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
        }
        else if (activationDtype == at::ScalarType::BFloat16)
        {
            mGemmRunner
                = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
                    cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, __nv_bfloat16,
                    __nv_bfloat16, __nv_bfloat16>>();
        }
    }
    else
    {
        TORCH_CHECK(false, "Unsupported quant mode for W4A16GemmRunner: ", quant_mode);
    }

    TORCH_CHECK(mGemmRunner, "Failed to create W4A16 GEMM runner for activation type ", c10::toString(activationDtype));
    mConfigs = mGemmRunner->getConfigs(); // Get configs via the interface
    TORCH_CHECK(!mConfigs.empty(), "Failed to get CUTLASS configs for W4A16 GEMM with activation type ",
        c10::toString(activationDtype));
}

at::Tensor W4A16GemmRunner::runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
    int64_t group_size_long, int64_t configIdx, std::optional<at::Tensor> bias, std::optional<at::Tensor> zeros) const
{
    TORCH_CHECK(A.is_cuda() && B_packed.is_cuda() && scales.is_cuda(), "All input tensors must be on CUDA");
    TORCH_CHECK(A.scalar_type() == mActivationDtype, "Activation tensor A's dtype ", c10::toString(A.scalar_type()),
        " does not match runner's expected dtype ", c10::toString(mActivationDtype));
    TORCH_CHECK(B_packed.scalar_type() == torch::kQUInt4x2 || B_packed.scalar_type() == torch::kInt8
            || B_packed.scalar_type() == torch::kUInt8,
        "B_packed must be quint4x2, int8, or uint8 (view of quantized data)");
    TORCH_CHECK(A.is_contiguous() && B_packed.is_contiguous() && scales.is_contiguous(),
        "All input tensors (A, B_packed, scales) must be contiguous");

    void const* zeros_ptr = nullptr;
    if (zeros.has_value())
    {
        TORCH_CHECK(zeros.value().is_cuda(), "Zeros tensor must be on CUDA");
        TORCH_CHECK(zeros.value().scalar_type() == torch::kFloat16 || zeros.value().scalar_type() == torch::kBFloat16,
            "Zeros must be FP16 or BF16");
        TORCH_CHECK(zeros.value().is_contiguous(), "Zeros tensor must be contiguous");
        zeros_ptr = zeros.value().data_ptr();
    }

    void const* bias_ptr = nullptr;
    if (bias.has_value())
    {
        TORCH_CHECK(bias.value().scalar_type() == torch::kFloat16 || bias.value().scalar_type() == torch::kBFloat16,
            "Bias must be FP16 or BF16");
        TORCH_CHECK(bias.value().is_contiguous(), "Bias tensor must be contiguous");
        bias_ptr = bias.value().data_ptr();
    }

    int M = 0, K_act = 0;
    // Determine M and K_act from A's dimensions.
    if (A.dim() == 2)
    {
        M = A.size(0);
        K_act = A.size(1);
    }
    else
    { // A.dim() >= 3
        M = A.size(0);
        for (int i = 1; i < A.dim() - 1; ++i)
            M *= A.size(i);
        K_act = A.size(A.dim() - 1);
    }

    // Assuming B_packed is [K_weights, N_packed_int4_pairs] or similar;
    // K_weights should match K_act, and N_orig is 2 * N_packed_int4_pairs.
    int K_weights = B_packed.size(0);
    int N_packed_int4 = B_packed.size(1); // Number of uint8_t elements, each holding two int4 values
    int N_orig = N_packed_int4 * 2;       // N_orig is the original N dimension

    TORCH_CHECK(K_act == K_weights, "K dimension mismatch: A.shape[-1]=", K_act, " vs B_packed.shape[0]=", K_weights);
    int K = K_act;
    int group_size = static_cast<int>(group_size_long);

    std::vector<int64_t> output_shape_vec;
    if (A.dim() == 2)
    {
        output_shape_vec = {static_cast<int64_t>(M), static_cast<int64_t>(N_orig)};
    }
    else
    {
        output_shape_vec.reserve(A.dim());
        for (int i = 0; i < A.dim() - 1; ++i)
            output_shape_vec.push_back(A.size(i));
        output_shape_vec.push_back(N_orig);
    }

    // Set output dtype based on activation dtype
    torch::ScalarType output_dtype;
    if (mActivationDtype == at::ScalarType::Half)
    {
        output_dtype = torch::kFloat16;
    }
    else if (mActivationDtype == at::ScalarType::BFloat16)
    {
        output_dtype = torch::kBFloat16;
    }
    else
    {
        TORCH_CHECK(false, "Unsupported activation type for output dtype determination");
    }

    torch::Tensor C_tensor = torch::empty(output_shape_vec, A.options().dtype(output_dtype));

    void const* A_ptr = A.data_ptr();

    TORCH_CHECK(B_packed.is_contiguous(), "B_packed tensor must be contiguous");
    void const* B_ptr = B_packed.data_ptr();
    void const* scales_ptr = scales.data_ptr();
    void* C_ptr = C_tensor.data_ptr();

    tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config_to_use;
    if (configIdx >= 0 && configIdx < getNumConfigs())
    {
        gemm_config_to_use = mConfigs.at(configIdx);
    }
    else
    {
        gemm_config_to_use = mConfigs.at(0);
    }

    size_t workspace_bytes = mGemmRunner->getWorkspaceSize(M, N_orig, K);
    torch::Tensor workspace_tensor = torch::empty(
        {static_cast<int64_t>(workspace_bytes)}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
    char* workspace_ptr = nullptr;
    if (workspace_bytes > 0)
    {
        workspace_ptr = reinterpret_cast<char*>(workspace_tensor.data_ptr());
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index());

    mGemmRunner->gemm(A_ptr, B_ptr, scales_ptr, zeros_ptr, bias_ptr,
        1.0f, // alpha
        C_ptr, M, N_orig, K, group_size, gemm_config_to_use, workspace_ptr, workspace_bytes, stream);

    return C_tensor;
}

int64_t W4A16GemmRunner::getNumConfigs() const
{
    TORCH_CHECK(mGemmRunner, "W4A16GemmRunner not initialized properly.");
    return static_cast<int64_t>(mConfigs.size());
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::W4A16GemmRunner>("W4A16GemmRunner")
        .def(torch::init<at::ScalarType, int64_t>())
        .def("run_gemm", &torch_ext::W4A16GemmRunner::runGemm)
        .def("get_num_configs", &torch_ext::W4A16GemmRunner::getNumConfigs);
}
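For orientation, a minimal Python sketch of how the class registered above can be driven through the custom-class bindings. This assumes importing tensorrt_llm loads the extension library that registers torch.classes.trtllm.*; the tensor contents are random placeholders, and real weights must first be packed with preprocess_weights_for_mixed_gemm to produce meaningful results.

import torch
import tensorrt_llm  # assumed to load the extension that registers torch.classes.trtllm.*

# quant_mode 0 = FINEGRAINED_SCALE_ONLY, 1 = FINEGRAINED_SCALE_AND_ZEROS
runner = torch.classes.trtllm.W4A16GemmRunner(torch.float16, 0)

M, K, N, group_size = 8, 4096, 4096, 128
A = torch.randn(M, K, dtype=torch.float16, device="cuda")                         # activations
B_packed = torch.randint(0, 256, (K, N // 2), dtype=torch.uint8, device="cuda")   # two int4 values per byte
scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")

# configIdx = -1 falls back to the first CUTLASS config; bias and zeros are optional.
C = runner.run_gemm(A, B_packed, scales, group_size, -1, None, None)
print(C.shape)  # torch.Size([8, 4096])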
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (new file, 44 lines)
@@ -0,0 +1,44 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include <torch/extension.h>

namespace torch_ext
{

class W4A16GemmRunner : public torch::CustomClassHolder
{
public:
    explicit W4A16GemmRunner(at::ScalarType activationDtype, int64_t quant_mode = 0);

    at::Tensor runGemm(at::Tensor const& A, at::Tensor const& B_packed, at::Tensor const& scales,
        int64_t group_size_long, int64_t configIdx = -1, std::optional<at::Tensor> bias = std::nullopt,
        std::optional<at::Tensor> zeros = std::nullopt) const;

    int64_t getNumConfigs() const;

private:
    std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface> mGemmRunner;
    std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mConfigs;
    at::ScalarType mActivationDtype;
};

} // namespace torch_ext
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
import torch

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._utils import get_sm_version

from ..attention_backend.interface import AttentionInputType
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,

@@ -579,6 +580,84 @@ def _(
                          dtype=output_dtype)


class W4A16GemmRunner(TunableRunner):
    _runner_dict = dict()
    MAX_SUPPORTED_SM_VERSION = 90

    def __init__(self, activation_dtype: torch.dtype, quant_mode: int):
        instance_key = (activation_dtype, quant_mode)
        if instance_key not in W4A16GemmRunner._runner_dict:
            W4A16GemmRunner._runner_dict[
                instance_key] = torch.classes.trtllm.W4A16GemmRunner(
                    activation_dtype, quant_mode)
        self._w4a16_gemm_runner = W4A16GemmRunner._runner_dict[instance_key]

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
    ) -> List[int]:
        return list(range(self._w4a16_gemm_runner.get_num_configs()))

    def forward(self,
                inputs: List[torch.Tensor],
                tactic: int = -1,
                do_preparation: bool = False,
                **kwargs) -> torch.Tensor:

        if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
            raise ValueError(
                f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
            )

        activation, weights_packed, scales = inputs

        return self._w4a16_gemm_runner.run_gemm(
            activation,
            weights_packed,
            scales,
            kwargs["group_size"],
            tactic,
            kwargs["bias"],
            kwargs["zeros"],
        )


@torch.library.custom_op("trtllm::w4a16_gemm", mutates_args=())
def w4a16_gemm(input: torch.Tensor,
               weight: torch.Tensor,
               scales: torch.Tensor,
               group_size: int,
               has_zero_point: bool,
               bias: Optional[torch.Tensor] = None,
               zeros: Optional[torch.Tensor] = None) -> torch.Tensor:

    assert not has_zero_point or zeros is not None, "Expected 'zeros' tensor when has_zero_point is True"

    tuner = AutoTuner.get()

    tuning_config = TuningConfig(dynamic_tensor_specs=(
        # For tensor index 0 (input A), tune dimension 0 (the M dimension)
        DynamicTensorSpec(0, 0, (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32,
                                 16, 8, 4, 2, 1), last_positive_power_of_2), ))

    # NOTE: quant_mode == 0 means scale-only quantization (FINEGRAINED_SCALE_ONLY) and zeros is unused;
    # otherwise both scale and zero point are used.
    quant_mode = 1 if has_zero_point else 0
    if quant_mode == 0:
        assert zeros is None, "When quant_mode is 0 (FINEGRAINED_SCALE_ONLY), zeros must be None"

    w4a16_gemm_runner = W4A16GemmRunner(input.dtype, quant_mode)

    kwargs = {"group_size": group_size, "zeros": zeros, "bias": bias}
    _, best_tactic = tuner.choose_one("trtllm::w4a16_gemm::gemm",
                                      [w4a16_gemm_runner], tuning_config,
                                      [input, weight, scales], **kwargs)

    return w4a16_gemm_runner(inputs=[input, weight, scales],
                             tactic=best_tactic,
                             **kwargs)


@torch.library.custom_op("trtllm::attention", mutates_args=())
def attention(
    q: torch.Tensor,
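A minimal sketch of calling the custom op registered above, mirroring the unit test added later in this commit (shapes and preprocessing follow that test; the import of the custom-ops module is assumed to be what registers trtllm::w4a16_gemm):

import torch
import tensorrt_llm
from tensorrt_llm._torch.custom_ops import torch_custom_ops  # assumed to register trtllm::w4a16_gemm
from tensorrt_llm.quantization.functional import preprocess_weights_for_mixed_gemm

M, K, N, group_size = 16, 1024, 2048, 128
activation = torch.randn(M, K, dtype=torch.float16, device="cuda")

# Random packed int4 weights: an int32 tensor of shape [K, N // 8] viewed as int8 gives [K, N // 2].
raw_packed = torch.randint(-2**31, 2**31, (K, N // 8), dtype=torch.int32).view(torch.int8)
weight = preprocess_weights_for_mixed_gemm(raw_packed, torch.quint4x2,
                                           torch.float16).cuda().contiguous()
scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")

out = torch.ops.trtllm.w4a16_gemm(activation, weight, scales, group_size,
                                  False, None, zeros=None)  # scale-only mode
print(out.shape)  # torch.Size([16, 2048])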
@@ -16,6 +16,9 @@ from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                     AllReduceStrategy)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.quantization.functional import \
    preprocess_weights_for_mixed_gemm
from tensorrt_llm.quantization.mode import QuantAlgo

from ...models.modeling_utils import QuantConfig
from ..utils import Fp4QuantizedTensor

@@ -106,6 +109,14 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict]):

    weight = load_weight_shard(weights[0]['weight'], module.tp_size,
                               module.tp_rank, module.tp_mode, device)

    if module.has_w4a16_awq:
        # NOTE: Without this preprocessing at runtime, the GEMM outputs NaNs. To use
        # preprocess_weights_for_mixed_gemm, the weight must be cast to int8 first.
        weight = preprocess_weights_for_mixed_gemm(
            weight.T.to(torch.int8).contiguous().cpu(), torch.quint4x2,
            torch.float16).cuda().contiguous()

    copy_weight(module.weight, weight)

    if module.bias is not None:

@@ -194,7 +205,7 @@ class LinearMethodBase(ABC):
    """

    @abstractmethod
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        """
        Load weights for the VANILLA weight mode.
        """

@@ -202,7 +213,7 @@ class LinearMethodBase(ABC):

    @abstractmethod
    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]):
                                      weights: List[Dict]) -> None:
        """
        Load weights for the FUSED_QKV_LINEAR weight mode.
        """

@@ -210,7 +221,7 @@ class LinearMethodBase(ABC):

    @abstractmethod
    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]):
                                          weights: List[Dict]) -> None:
        """
        Load weights for the FUSED_GATE_UP_LINEAR weight mode.
        """

@@ -242,18 +253,18 @@ class UnquantizedLinearMethod(LinearMethodBase):
        output = F.linear(input, module.weight, bias)
        return output

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]):
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)
        fused_weight = torch.cat((q_weight, k_weight, v_weight))
        copy_weight(module.weight, fused_weight)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]):
                                          weights: List[Dict]) -> None:
        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)
        fused_weight = torch.cat((gate_weight, up_weight))

@@ -321,7 +332,7 @@ class FP8QDQLinearMethod(LinearMethodBase):
            weight_scale.append(w["weight_scale"][...].reshape([]))
        return input_scale, weight_scale

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)
        input_scale, weight_scale = self.load_weight_scales(weights)
        if len(input_scale) != 0:

@@ -335,7 +346,7 @@ class FP8QDQLinearMethod(LinearMethodBase):
        copy_weight(module.weight_scale, weight_scale[0])

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]):
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)

@@ -358,7 +369,7 @@ class FP8QDQLinearMethod(LinearMethodBase):
        copy_weight(module.weight, fused_weight)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]):
                                          weights: List[Dict]) -> None:
        input_scale, weight_scale = self.load_weight_scales(weights)
        if len(input_scale) != 0:
            # Static quantization

@@ -428,7 +439,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
            scale_name = "weight_scale"
        return scale_name

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)

        scale_name = self._get_scale_name(weights)

@@ -440,7 +451,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
        module.inv_input_scale.data = 1.0 / module.input_scale

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]):
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)
        fused_weight = torch.cat((q_weight, k_weight, v_weight))

@@ -457,7 +468,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
        copy_weight(module.weight_scale, fused_fp8_block_scale)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]):
                                          weights: List[Dict]) -> None:
        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)
        fused_weight = torch.cat((gate_weight, up_weight))

@@ -566,7 +577,7 @@ class NVFP4LinearMethod(LinearMethodBase):

        return input_scale, weight_scale, alpha

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)

        input_scale, weight_scale, alpha = self.load_weight_scales(

@@ -588,7 +599,7 @@ class NVFP4LinearMethod(LinearMethodBase):
        copy_weight(module.alpha, alpha)

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]):
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)

@@ -609,7 +620,7 @@ class NVFP4LinearMethod(LinearMethodBase):
        copy_weight(module.weight, fused_weight)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]):
                                          weights: List[Dict]) -> None:
        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)
        fused_weight = torch.cat((gate_weight, up_weight))

@@ -696,7 +707,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
            weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
        return weight_scale

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)

        weight_scale = self.load_weight_scales(weights,

@@ -711,7 +722,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
        copy_weight(module.weight_scale, weight_scale)

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]):
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)
        fused_weight = torch.cat((q_weight, k_weight, v_weight))

@@ -727,7 +738,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
        copy_weight(module.weight_scale, weight_scale)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]):
                                          weights: List[Dict]) -> None:
        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)
        fused_weight = torch.cat((gate_weight, up_weight))
@@ -744,6 +755,141 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
        copy_weight(module.weight_scale, weight_scale)


class W4A16_AWQ_LinearMethod(LinearMethodBase):

    def create_weights(self, module: Linear, in_features: int,
                       out_features: int, bias: bool,
                       dtype: torch.dtype) -> None:
        # Quantized weights
        module.weight = Parameter(torch.empty(
            (in_features, out_features // 2),
            dtype=torch.int8,
        ),
                                  requires_grad=False)

        group_size = module.quant_config.group_size
        if in_features % group_size != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by group_size ({group_size}) "
                f"for INT4 per-group quantization scale dimensions.")

        module.weight_scale = Parameter(torch.empty(
            (out_features, in_features // group_size), dtype=dtype),
                                        requires_grad=False)
        # NOTE: Not every Linear has this tensor: pre_quant_scale is computed as an average and merged
        # into the LayerNorm for the QKV and Gate/Up projection layers when possible, so it only
        # appears for o_proj and down_proj.
        module.pre_quant_scale = None

        if bias:
            module.bias = Parameter(torch.empty((out_features), dtype=dtype),
                                    requires_grad=False)
        else:
            module.register_parameter("bias", None)

    def apply(self, module: Linear, input: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:

        if module.pre_quant_scale is not None:
            pre_quant_scale = module.pre_quant_scale.repeat(input.shape[0], 1)
            input = torch.mul(input, pre_quant_scale)

        bias = bias.contiguous() if bias is not None else None

        output = torch.ops.trtllm.w4a16_gemm(input.to(
            module.dtype).contiguous(),
                                             module.weight,
                                             module.weight_scale.T.contiguous(),
                                             module.quant_config.group_size,
                                             module.quant_config.has_zero_point,
                                             bias,
                                             zeros=None)
        return output

    def load_weight_scales(
            self,
            weights: List[Dict],
            tp_size: int = 1,
            tp_rank: int = 0,
            tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]:
        device = torch.device("cuda")
        q_weight_scale = load_weight_shard(weights[0]['weight_scale'],
                                           tp_size,
                                           tp_rank,
                                           tp_mode,
                                           device=device)
        k_weight_scale = load_weight_shard(weights[1]['weight_scale'],
                                           tp_size,
                                           tp_rank,
                                           tp_mode,
                                           device=device)
        v_weight_scale = load_weight_shard(weights[2]['weight_scale'],
                                           tp_size,
                                           tp_rank,
                                           tp_mode,
                                           device=device)
        weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale]

        return weight_scales

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)

        device = torch.device('cuda')
        pre_quant_scale = load_weight_shard(weights[0]['pre_quant_scale'],
                                            module.tp_size, module.tp_rank,
                                            module.tp_mode, device)
        module.pre_quant_scale = Parameter(
            torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
            requires_grad=False).to(device=device)

        weight_scale = load_weight_shard(weights[0]['weight_scale'],
                                         module.tp_size, module.tp_rank,
                                         module.tp_mode, device)

        copy_weight(module.pre_quant_scale, pre_quant_scale)
        copy_weight(module.weight_scale, weight_scale)

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)

        fused_weight = torch.cat((q_weight, k_weight, v_weight))
        fused_weight = preprocess_weights_for_mixed_gemm(
            fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
            torch.float16).cuda().contiguous()

        copy_weight(module.weight, fused_weight)

        weight_scales = self.load_weight_scales(weights)

        # Create concatenated weight scale tensor
        cat_weight_scale = torch.cat(weight_scales, dim=0)
        copy_weight(module.weight_scale, cat_weight_scale)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]) -> None:
        device = torch.device('cuda')
        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)

        fused_weight = torch.cat((gate_weight, up_weight))
        fused_weight = preprocess_weights_for_mixed_gemm(
            fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
            torch.float16).cuda().contiguous()

        copy_weight(module.weight, fused_weight)

        left_scale = load_weight_shard(weights[0]['weight_scale'],
                                       module.tp_size, module.tp_rank,
                                       module.tp_mode, device).contiguous()
        right_scale = load_weight_shard(weights[1]['weight_scale'],
                                        module.tp_size, module.tp_rank,
                                        module.tp_mode, device).contiguous()
        fused_scale = torch.cat([left_scale, right_scale], dim=0)
        copy_weight(module.weight_scale, fused_scale)


def get_quant_method(quant_config: Optional[QuantConfig] = None):
    if quant_config is None or not quant_config.layer_quant_mode.has_any_quant(
            exclude_kv_cache=True):

@@ -756,6 +902,9 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
        return NVFP4LinearMethod()
    if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
        return W4A8MXFP4FP8LinearMethod()
    if quant_config.layer_quant_mode.is_int4_weight_only_per_group(
    ) and quant_config.quant_algo == QuantAlgo.W4A16_AWQ:
        return W4A16_AWQ_LinearMethod()
    raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}')

@@ -859,6 +1008,12 @@ class Linear(nn.Module):
        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
        )

    @property
    def has_w4a16_awq(self):
        assert self._weights_created
        return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
        ) and self.quant_config.quant_algo == QuantAlgo.W4A16_AWQ

    def apply_linear(self,
                     input,
                     bias,
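For reference, a hedged illustration (not part of the diff) of the parameter shapes create_weights above produces, using a hypothetical 4096 -> 11008 projection with group_size 128:

import torch

in_features, out_features, group_size = 4096, 11008, 128

weight = torch.empty((in_features, out_features // 2), dtype=torch.int8)   # packed int4, two values per byte
weight_scale = torch.empty((out_features, in_features // group_size),
                           dtype=torch.float16)                            # one scale per group per output channel
bias = torch.empty((out_features, ), dtype=torch.float16)

print(tuple(weight.shape), tuple(weight_scale.shape), tuple(bias.shape))
# (4096, 5504) (11008, 32) (11008,)
# At apply() time the scale is passed transposed, i.e. as [in_features // group_size, out_features].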
tests/unittest/_torch/thop/test_w4a16_gemm.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import pytest
import torch
from utils.util import woq_assert_near_eq, woq_groupwise_gt_matmul

import tensorrt_llm
from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner
from tensorrt_llm._utils import get_sm_version


@pytest.mark.parametrize(
    "m, n, k, group_size, activation_dtype, has_pre_quant, has_zero, has_bias",
    [
        (3, 1024, 64, 64, torch.bfloat16, True, False, True),
        (128, 1024, 256, 64, torch.bfloat16, True, False, True),
        (192, 2048, 384, 64, torch.bfloat16, True, False, True),
        (256, 2048, 1024, 64, torch.bfloat16, True, False, True),
        (4, 1024, 128, 128, torch.bfloat16, True, False, True),
        (64, 1024, 256, 128, torch.bfloat16, True, False, True),
        (384, 2048, 384, 128, torch.bfloat16, True, False, True),
        (512, 2048, 1024, 128, torch.bfloat16, True, False, True),
        (4, 1024, 128, 128, torch.bfloat16, True, True, True),
        (64, 1024, 256, 128, torch.bfloat16, True, True, True),
        (384, 2048, 384, 128, torch.bfloat16, True, True, True),
        (512, 2048, 1024, 128, torch.bfloat16, True, True, False),
        (3, 1024, 64, 64, torch.float16, True, False, True),
        (128, 1024, 256, 64, torch.float16, True, False, True),
        (192, 2048, 384, 64, torch.float16, True, False, True),
        (256, 2048, 1024, 64, torch.float16, True, False, True),
        (4, 1024, 128, 128, torch.float16, True, False, True),
        (64, 1024, 256, 128, torch.float16, True, False, True),
        (384, 2048, 384, 128, torch.float16, True, False, True),
        (512, 2048, 1024, 128, torch.float16, True, False, True),
        (4, 1024, 128, 128, torch.float16, True, True, True),
        (64, 1024, 256, 128, torch.float16, True, True, True),
        (384, 2048, 384, 128, torch.float16, True, True, True),
        (512, 2048, 1024, 128, torch.float16, True, True, False),
    ])
def test_matmul_activation_int4_input(m, n, k, group_size, activation_dtype,
                                      has_pre_quant, has_zero, has_bias):
    torch.manual_seed(0)
    device = "cuda"

    if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION:
        pytest.skip(f"W4A16 not supported for SM version {get_sm_version()}")

    total_groups = (k + group_size - 1) // group_size
    activation = torch.randn(m, k, dtype=activation_dtype, device=device)
    scale = torch.rand(total_groups, n, dtype=activation_dtype, device=device)
    zero = torch.randn(total_groups, n, dtype=activation_dtype,
                       device=device) if has_zero else None
    pre_quant_scale = torch.rand(1, k, dtype=activation_dtype, device=device)
    bias = torch.randn(1, n, dtype=activation_dtype,
                       device=device) if has_bias else None

    num_weights_in_32_bits = 8  # for torch.quint4x2
    unprocessed_int_weight = torch.randint(-2**31,
                                           2**31,
                                           (k, n // num_weights_in_32_bits),
                                           dtype=torch.int32,
                                           device=device)
    unprocessed_weight = unprocessed_int_weight.view(torch.int8)

    # Ref quantized weights
    unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
    ref_q_weight = unpacker(unprocessed_weight.cpu()).contiguous().cuda()

    cuda_q_weight = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm(
        unprocessed_weight.cpu(), torch.quint4x2,
        activation_dtype).cuda().contiguous()

    scale_ref = scale.repeat_interleave(group_size, dim=0)[:k, :]
    ref_th_weight = ref_q_weight.to(activation_dtype) * scale_ref

    if has_zero:
        zero_ref = zero.repeat_interleave(group_size, dim=0)[:k, :]
        ref_th_weight += zero_ref

    if has_pre_quant:
        pre_quant_scale = pre_quant_scale.repeat(m, 1)
        activation = torch.mul(activation, pre_quant_scale)

    output = torch.ops.trtllm.w4a16_gemm(
        activation.contiguous(),
        cuda_q_weight,
        scale.contiguous(),
        group_size,
        has_zero,
        bias.contiguous() if has_bias else None,
        zeros=zero)

    ref = woq_groupwise_gt_matmul(activation,
                                  ref_th_weight.to(activation_dtype), bias)

    woq_assert_near_eq(ref, output, 2)
tests/unittest/_torch/thop/test_w4a16_linear.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import pytest
import torch

import tensorrt_llm.quantization.functional
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.custom_ops.torch_custom_ops import W4A16GemmRunner
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig


@pytest.mark.parametrize("weights_dtype", [torch.uint8])
@pytest.mark.parametrize(
    "dtype",
    [torch.float16],
)
def test_w4a16_linear(dtype, weights_dtype, has_zero=False):

    if get_sm_version() > W4A16GemmRunner.MAX_SUPPORTED_SM_VERSION:
        pytest.skip(
            f"W4A16 is not supported in this SM version {get_sm_version()}")

    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    GROUP_SIZE = 128
    torch.manual_seed(0)

    total_groups = (HIDDEN_SIZE + GROUP_SIZE - 1) // GROUP_SIZE

    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    w = torch.randint(0,
                      2**32 - 1, (HIDDEN_SIZE, HIDDEN_SIZE // 8),
                      dtype=torch.uint32,
                      device=x.device)
    w = w.view(weights_dtype)

    pre_quant_scale = torch.rand(HIDDEN_SIZE, dtype=dtype).cuda()
    weight_scale = torch.rand(total_groups, HIDDEN_SIZE,
                              dtype=torch.float32).cuda()
    bias = torch.randn(HIDDEN_SIZE, dtype=dtype).cuda().contiguous()

    qc = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ,
                     group_size=GROUP_SIZE,
                     has_zero_point=has_zero)
    linear_w4a16 = Linear(in_features=HIDDEN_SIZE,
                          out_features=HIDDEN_SIZE,
                          bias=True,
                          dtype=dtype,
                          quant_config=qc)

    linear_w4a16.load_weights([{
        'pre_quant_scale': pre_quant_scale,
        'weight': w.T,
        'weight_scale': weight_scale.T,
        'bias': bias
    }])

    linear_w4a16 = linear_w4a16.cuda()

    w = w.to(torch.int8)
    preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
    w = preprocessor(w.contiguous().cpu(), torch.quint4x2,
                     x.dtype).cuda().contiguous()

    torch.testing.assert_close(linear_w4a16.weight, w)

    with torch.inference_mode(), autotune():
        output = linear_w4a16.forward(x)

    # Reference linear
    with torch.inference_mode():
        pre_quant_scale = pre_quant_scale.repeat(SEQ_LEN, 1)
        x = torch.mul(x, pre_quant_scale)

        output_ref = torch.ops.trtllm.w4a16_gemm(x.contiguous(),
                                                 w,
                                                 weight_scale.type(x.dtype),
                                                 GROUP_SIZE,
                                                 has_zero,
                                                 bias,
                                                 zeros=None)
    torch.cuda.synchronize()
    torch.testing.assert_close(output, output_ref)
@@ -377,3 +377,23 @@ def default_dtype(dtype: torch.dtype):
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(cur_default)


def woq_assert_near_eq(ref, act, wTypeId):
    # match the scale in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp
    if wTypeId == 1:
        bits_in_type = 8
    else:
        bits_in_type = 4
    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))

    max_val = torch.max(abs(ref)).item()
    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
    torch.testing.assert_close(ref, act, atol=atol, rtol=1e-7)


def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None):
    ref = torch.matmul(mat1, ref_torch_weights)
    if bias is not None:
        ref += bias
    return ref
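A quick worked example of the tolerance woq_assert_near_eq computes, with an assumed reference magnitude chosen only for illustration: for int4 weights (wTypeId != 1), quant_range_scale = 1 / 2**(4 - 1) = 0.125, so a reference tensor whose largest magnitude is 20.0 gets atol = 20.0 * 0.125 * 1.5 = 3.75.

bits_in_type = 4                                           # int4 path in woq_assert_near_eq
quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))   # 0.125
max_val = 20.0                                             # assumed max(|ref|), illustration only
atol = (max_val * quant_range_scale) * 1.5                 # 3.75
assert abs(atol - 3.75) < 1e-12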