# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.nn.functional as F

from .._common import default_net, default_trtnet
from .._utils import (get_sm_version, str_dtype_to_np, str_dtype_to_trt,
                      trt_dtype_to_np)
from ..functional import (Tensor, _add_plugin_info, _create_tensor, cast, clip,
                          constant, flatten, layer_norm, matmul,
                          repeat_interleave, rms_norm, round, sum, view)
from ..layers.linear import ColumnLinear
from ..parameter import Parameter
from ..plugin import TRT_LLM_PLUGIN_NAMESPACE
from .mode import QuantMode


def smooth_quant_gemm(input: Tensor, weights: Tensor, scales_a: Tensor,
                      scales_b: Tensor, per_token_scaling: bool,
                      per_channel_scaling: bool, dtype: str) -> Tensor:
    if not default_net().plugin_config.smooth_quant_gemm_plugin:
        if per_token_scaling and input.size(0) == -1:
            # WAR: DQ with per-token scaling doesn't support dynamic shapes

            scale_one = constant(np.array(1.0, dtype=np.float32))
            input = dequantize(input, scale_one, 0, 'float32')
            weights = dequantize(weights, scale_one, 0, 'float32')
            result = matmul(input, weights, False, True, False)
            scales = matmul(scales_a, scales_b, False, False, False)
            result = result * scales
            result = cast(result, dtype)
            return result
        else:
            if not per_token_scaling:
                scales_a = view(scales_a, [])
            else:
                scales_a = flatten(scales_a)
            if not per_channel_scaling:
                scales_b = view(scales_b, [])
            else:
                scales_b = flatten(scales_b)
            input = dequantize(input, scales_a, 0, dtype)
            weights = dequantize(weights, scales_b, 0, dtype)
            result = matmul(input, weights, False, True, False)
            return result
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'SmoothQuantGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        per_channel_scaling = 1 if per_channel_scaling else 0
        per_channel_scaling = trt.PluginField(
            "has_per_channel_scaling",
            np.array(per_channel_scaling, dtype=np.int32),
            trt.PluginFieldType.INT32)

        per_token_scaling = 1 if per_token_scaling else 0
        per_token_scaling = trt.PluginField(
            "has_per_token_scaling", np.array(per_token_scaling,
                                              dtype=np.int32),
            trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.smooth_quant_gemm_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection(
            [per_channel_scaling, per_token_scaling, pf_type])
        gemm_plug = plg_creator.create_plugin("sq_gemm", pfc)
        plug_inputs = [
            input.trt_tensor, weights.trt_tensor, scales_a.trt_tensor,
            scales_b.trt_tensor
        ]
        layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)
        _add_plugin_info(layer, plg_creator, "sq_gemm", pfc)
        if not default_net().strongly_typed:
            layer.get_input(0).set_dynamic_range(-127, 127)
            layer.get_input(1).set_dynamic_range(-127, 127)
        return _create_tensor(layer.get_output(0), layer)


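# Reference for the SmoothQuant GEMM math above (illustrative sketch only, not
# executed; it mirrors the non-plugin fallback, which dequantizes with
# per-token/per-channel scales and multiplies against the transposed weights):
#
#   import numpy as np
#   a_q = np.random.randint(-127, 128, (4, 8), dtype=np.int8)   # activations [M, K]
#   b_q = np.random.randint(-127, 128, (16, 8), dtype=np.int8)  # weights [N, K]
#   scales_a = np.random.rand(4, 1).astype(np.float32)          # per-token scales
#   scales_b = np.random.rand(1, 16).astype(np.float32)         # per-channel scales
#   ref = (a_q.astype(np.float32) @ b_q.astype(np.float32).T) * (scales_a @ scales_b)

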
def qserve_gemm_per_group(input: Tensor,
                          act_scales: Tensor,
                          weights: Tensor,
                          s1_scales: Tensor,
                          s2_scales: Tensor,
                          s2_zeros: Tensor,
                          group_size: int = 128) -> Tensor:
    if not default_net().plugin_config.qserve_gemm_plugin:
        raise TypeError("QServe Quant GEMM is only supported with plugin")
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'QServeGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        p_dtype = default_net().plugin_config.qserve_gemm_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)
        pf_group_size = trt.PluginField("group_size",
                                        np.array([group_size], np.int32),
                                        trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection([pf_type, pf_group_size])
        gemm_plug = plg_creator.create_plugin("qserve_gemm", pfc)
        plug_inputs = [
            input.trt_tensor, weights.trt_tensor, s2_zeros.trt_tensor,
            s2_scales.trt_tensor, s1_scales.trt_tensor, act_scales.trt_tensor
        ]
        layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)
        _add_plugin_info(layer, plg_creator, "qserve_gemm", pfc)
        if not default_net().strongly_typed:
            # Not used at runtime, but must be kept; otherwise TRT reports the following API usage error:
            # input/output with DataType Int8 in network without Q/DQ layers must have dynamic range set when no calibrator is used
            layer.get_input(0).set_dynamic_range(-128, 127)
            layer.get_input(1).set_dynamic_range(-128, 127)
            layer.get_input(2).set_dynamic_range(-128, 127)
            layer.get_input(3).set_dynamic_range(-128, 127)
        return _create_tensor(layer.get_output(0), layer)


def qserve_gemm_per_channel(input: Tensor, act_scales: Tensor, act_sums: Tensor,
                            weights: Tensor, s1_scales: Tensor,
                            s1_szeros: Tensor) -> Tensor:
    if not default_net().plugin_config.qserve_gemm_plugin:
        raise TypeError("QServe Quant GEMM is only supported with plugin")
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'QServeGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        p_dtype = default_net().plugin_config.qserve_gemm_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)
        pf_group_size = trt.PluginField("group_size", np.array([-1], np.int32),
                                        trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection([pf_type, pf_group_size])
        gemm_plug = plg_creator.create_plugin("qserve_gemm", pfc)

        plug_inputs = [
            input.trt_tensor, weights.trt_tensor, s1_scales.trt_tensor,
            s1_szeros.trt_tensor, act_sums.trt_tensor, act_scales.trt_tensor
        ]
        layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)
        _add_plugin_info(layer, plg_creator, "qserve_gemm", pfc)

        if not default_net().strongly_typed:
            # Not used at runtime, but must be kept; otherwise TRT reports the following API usage error:
            # input/output with DataType Int8 in network without Q/DQ layers must have dynamic range set when no calibrator is used
            layer.get_input(0).set_dynamic_range(-128, 127)
            layer.get_input(1).set_dynamic_range(-128, 127)

        return _create_tensor(layer.get_output(0), layer)


def fp8_rowwise_gemm(input: Tensor, weights: Tensor, scales_a: Tensor,
                     scales_b: Tensor, per_token_scaling: bool,
                     per_channel_scaling: bool) -> Tensor:
    if not default_net().plugin_config.fp8_rowwise_gemm_plugin:
        raise TypeError("Fp8 Rowwise GEMM is only supported with plugin")
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'Fp8RowwiseGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        per_channel_scaling = 1 if per_channel_scaling else 0
        per_channel_scaling = trt.PluginField(
            "has_per_channel_scaling",
            np.array(per_channel_scaling, dtype=np.int32),
            trt.PluginFieldType.INT32)

        per_token_scaling = 1 if per_token_scaling else 0
        per_token_scaling = trt.PluginField(
            "has_per_token_scaling", np.array(per_token_scaling,
                                              dtype=np.int32),
            trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.fp8_rowwise_gemm_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection(
            [per_channel_scaling, per_token_scaling, pf_type])
        gemm_plug = plg_creator.create_plugin("fp8_rowwise_gemm", pfc)
        plug_inputs = [
            input.trt_tensor, weights.trt_tensor, scales_a.trt_tensor,
            scales_b.trt_tensor
        ]
        layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)
        _add_plugin_info(layer, plg_creator, "fp8_rowwise_gemm", pfc)
        if not default_net().strongly_typed:
            layer.get_input(0).set_dynamic_range(-448, 448)
            layer.get_input(1).set_dynamic_range(-448, 448)
        return _create_tensor(layer.get_output(0), layer)


def weight_only_quant_matmul(input: Tensor,
                             weights: Tensor,
                             scales: Tensor,
                             weightTypeId: int,
                             dtype: str = 'float16',
                             transa: bool = False,
                             transb: bool = False) -> Tensor:
    if not default_net(
    ).plugin_config.weight_only_quant_matmul_plugin or transa or transb:
        scale_axis = 0 if transb else 1
        if weights.dtype != trt.int8:
            # Q->DQ
            weights = quantize(weights, scales, dtype='int8', axis=1)
            weights = dequantize(weights, scales, scale_axis, input.dtype)
        else:
            weights = dequantize(weights, scales, scale_axis, input.dtype)

        res = matmul(input, weights, transa=transa, transb=transb)
        return cast(res, dtype)
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'WeightOnlyQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        weight_type_id = trt.PluginField("weight_type_id",
                                         np.array(weightTypeId, dtype=np.int32),
                                         trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.weight_only_quant_matmul_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection([pf_type, weight_type_id])
        matmul_plug = plg_creator.create_plugin("woq_matmul", pfc)
        plug_inputs = [input.trt_tensor, weights.trt_tensor, scales.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, matmul_plug)
        _add_plugin_info(layer, plg_creator, "woq_matmul", pfc)
        if not default_net().strongly_typed:
            layer.get_input(1).set_dynamic_range(-127, 127)
        return _create_tensor(layer.get_output(0), layer)


def weight_only_groupwise_quant_matmul(input: Tensor,
                                       pre_quant_scale: Tensor,
                                       weights: Tensor,
                                       scales: Tensor,
                                       zeros: Tensor,
                                       biases: Tensor,
                                       alpha: Parameter,
                                       quant_algo: int,
                                       group_size: int,
                                       dtype: str = 'float16') -> Tensor:
    if not default_net(
    ).plugin_config.weight_only_groupwise_quant_matmul_plugin:
        scales = repeat_interleave(scales, group_size, 0)
        weights = quantize(weights, scales, dtype='int8', axis=1)
        weights = dequantize(weights, scales, 1, input.dtype)

        if quant_algo & 8:
            # fp8_alpha
            input = input * alpha.value
        if quant_algo & 4:
            # pre quant
            input = input * pre_quant_scale
        elif quant_algo & 2:
            # zero
            zeros = repeat_interleave(zeros, group_size, 0)
            weights += zeros
        res = matmul(input, weights)
        if quant_algo & 1:
            # bias
            res += biases

        return cast(res, dtype)
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'WeightOnlyGroupwiseQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        quant_algo_ = trt.PluginField("quant_algo",
                                      np.array(quant_algo, dtype=np.int32),
                                      trt.PluginFieldType.INT32)
        group_size_ = trt.PluginField("group_size",
                                      np.array(group_size, dtype=np.int32),
                                      trt.PluginFieldType.INT32)

        if alpha:
            alpha.is_buffer = True
            alpha_value = alpha.raw_value[0]
        else:
            alpha_value = 1.0

        alpha_ = trt.PluginField("alpha", np.array(alpha_value,
                                                   dtype=np.float32),
                                 trt.PluginFieldType.FLOAT32)

        p_dtype = default_net(
        ).plugin_config.weight_only_groupwise_quant_matmul_plugin
        pf_type_ = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection(
            [pf_type_, quant_algo_, group_size_, alpha_])

        matmul_plug = plg_creator.create_plugin("woq_groupwise_matmul", pfc)

        # quant_algo = use_int8_weight * 16 + fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
        plug_inputs = [input.trt_tensor]

        # Flags for indicating whether the corresponding inputs are applied in quant_algo
        # quant_algo = use_int8_weight * INT8_WEIGHT + fp8_alpha * FP8_ALPHA + pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS
        # Here use_int8_weight, pre_quant_scale, zero and bias are boolean type
        BIAS = 1
        ZERO = 2
        PRE_QUANT_SCALE = 4

        if quant_algo & PRE_QUANT_SCALE:
            plug_inputs += [pre_quant_scale.trt_tensor]

        plug_inputs += [weights.trt_tensor, scales.trt_tensor]

        if quant_algo & ZERO:
            plug_inputs += [zeros.trt_tensor]
        if quant_algo & BIAS:
            plug_inputs += [biases.trt_tensor]

        layer = default_trtnet().add_plugin_v2(plug_inputs, matmul_plug)
        _add_plugin_info(layer, plg_creator, "woq_groupwise_matmul", pfc)

        return _create_tensor(layer.get_output(0), layer)


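# Illustrative note on the quant_algo bit field used above (sketch only, based
# on the flag comments inside weight_only_groupwise_quant_matmul): the bits
# mean BIAS = 1, ZERO = 2, PRE_QUANT_SCALE = 4, FP8_ALPHA = 8, INT8_WEIGHT = 16.
# For example, a hypothetical AWQ-style configuration that carries a pre-quant
# scale, zeros and a bias would be encoded as:
#
#   BIAS, ZERO, PRE_QUANT_SCALE = 1, 2, 4
#   quant_algo = PRE_QUANT_SCALE | ZERO | BIAS  # == 7
#   assert quant_algo & PRE_QUANT_SCALE and quant_algo & ZERO and quant_algo & BIAS

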
# TODO: Should be renamed to layer_norm_quantize.
def smooth_quant_layer_norm(input: Tensor,
                            normalized_shape: Union[int, Tuple[int]],
                            weight: Optional[Tensor] = None,
                            bias: Optional[Tensor] = None,
                            scale: Optional[Tensor] = None,
                            eps: float = 1e-05,
                            use_diff_of_squares: bool = True,
                            dynamic_act_scaling: bool = False) -> Tensor:
    if not default_net().plugin_config.layernorm_quantization_plugin:
        dtype = trt_dtype_to_np(input.dtype)
        if weight is None:
            weight = constant(np.ones(normalized_shape, dtype=dtype))
        if bias is None:
            bias = constant(np.zeros(normalized_shape, dtype=dtype))
        result = layer_norm(input, normalized_shape, weight, bias, eps,
                            use_diff_of_squares)
        if not dynamic_act_scaling:
            return quantize_tensor(result, scale)
        else:
            return quantize_per_token(result)
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'LayernormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        eps = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                              trt.PluginFieldType.FLOAT32)
        use_diff_of_squares = trt.PluginField(
            "use_diff_of_squares",
            np.array([int(use_diff_of_squares)], dtype=np.int32),
            trt.PluginFieldType.INT32)

        dyn_act_scaling = trt.PluginField(
            "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),
            trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.layernorm_quantization_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)
        pfc = trt.PluginFieldCollection(
            [eps, use_diff_of_squares, dyn_act_scaling, pf_type])
        layernorm_plug = plg_creator.create_plugin("layernorm_quantized", pfc)
        normalized_shape = [normalized_shape] if isinstance(
            normalized_shape, int) else normalized_shape
        if weight is None:
            weight = constant(
                np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
        if bias is None:
            bias = constant(
                np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))

        # LayerNorm plugin only supports float32 scale
        scale = cast(scale, "float32")
        plug_inputs = [
            input.trt_tensor, weight.trt_tensor, bias.trt_tensor,
            scale.trt_tensor
        ]
        layer = default_trtnet().add_plugin_v2(plug_inputs, layernorm_plug)
        if not default_net().strongly_typed:
            layer.get_output(0).set_dynamic_range(-127, 127)
        _add_plugin_info(layer, plg_creator, "layernorm_quantized", pfc)
        if not dynamic_act_scaling:
            return _create_tensor(layer.get_output(0), layer)

        return _create_tensor(layer.get_output(0),
                              layer), _create_tensor(layer.get_output(1), layer)


# TODO: Should be renamed to rms_norm_quantize. This is also used by QServe.
def smooth_quant_rms_norm(
        input: Tensor,
        normalized_shape: Union[int, Tuple[int]],
        weight: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
        clamp_val: Optional[Tensor] = None,
        eps: float = 1e-05,
        dynamic_act_scaling: bool = False,
        scale_dtype='float32',
        sum_per_token: bool = False,
        sum_dtype='float32'
) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
    if sum_per_token and not dynamic_act_scaling:
        raise ValueError(
            "sum_per_token is only allowed if dynamic_act_scaling is enabled!")

    if not default_net().plugin_config.rmsnorm_quantization_plugin:
        result = rms_norm(input, normalized_shape, 1, weight, eps)
        if bias is not None:
            result += bias
        if not dynamic_act_scaling:
            return quantize_tensor(result, scale)
        else:
            return quantize_per_token(result, clamp_val, scale_dtype,
                                      sum_per_token, sum_dtype)
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'RmsnormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        output_type = trt.PluginField("out_type_id",
                                      np.array([int(trt.int8)], np.int32),
                                      trt.PluginFieldType.INT32)
        quant_mode = trt.PluginField(
            "quant_mode",
            np.array([int(QuantMode.use_smooth_quant(per_token=True))],
                     np.int32), trt.PluginFieldType.INT32)
        clamp_enabled = trt.PluginField(
            "clamp_enabled", np.array([clamp_val is not None], np.int32),
            trt.PluginFieldType.INT32)

        eps = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                              trt.PluginFieldType.FLOAT32)

        dyn_act_scaling = trt.PluginField(
            "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),
            trt.PluginFieldType.INT32)

        sum_per_token_pf = trt.PluginField(
            "sum_per_token", np.array([int(sum_per_token)], np.int32),
            trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.rmsnorm_quantization_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)
        pfc = trt.PluginFieldCollection([
            eps, dyn_act_scaling, sum_per_token_pf, clamp_enabled, quant_mode,
            pf_type, output_type
        ])
        rmsnorm_plug = plg_creator.create_plugin("rmsnorm_quantized", pfc)
        normalized_shape = [normalized_shape] if isinstance(
            normalized_shape, int) else normalized_shape
        if weight is None:
            weight = constant(
                np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
        if bias is None:
            bias = constant(
                np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))

        # TODO: Why not fuse scale (which seems to be a per-tensor scaling factor of the original values) into weight?
        if scale is None:
            scale = constant(np.ones(1, dtype=str_dtype_to_np(p_dtype)))

        # RMS Norm Plugin only supports float32 scale
        scale = cast(scale, "float32")

        plug_inputs = [
            input.trt_tensor, weight.trt_tensor, bias.trt_tensor,
            scale.trt_tensor
        ]
        if clamp_val:
            plug_inputs += [clamp_val.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, rmsnorm_plug)
        if not default_net().strongly_typed:
            layer.get_output(0).set_dynamic_range(-127, 127)
        _add_plugin_info(layer, plg_creator, "rmsnorm_quantized", pfc)
        if not dynamic_act_scaling:
            return _create_tensor(layer.get_output(0), layer)

        output_quantized = _create_tensor(layer.get_output(0), layer)
        output_scales = _create_tensor(layer.get_output(1), layer)

        # TODO: The plugin should be able to directly output float16 scales
        if str_dtype_to_trt(scale_dtype) != output_scales.dtype:
            output_scales = cast(output_scales, scale_dtype)

        if not sum_per_token:
            return output_quantized, output_scales

        output_sums = _create_tensor(layer.get_output(2), layer)
        # TODO: The plugin should be able to directly output float16 sums
        if str_dtype_to_trt(sum_dtype) != output_sums.dtype:
            output_sums = cast(output_sums, sum_dtype)

        return output_quantized, output_scales, output_sums


def fp8_rowwise_rms_norm(input: Tensor,
                         normalized_shape: Union[int, Tuple[int]],
                         weight: Optional[Tensor] = None,
                         bias: Optional[Tensor] = None,
                         scale: Optional[Tensor] = None,
                         clamp_val: Optional[Tensor] = None,
                         eps: float = 1e-05,
                         dynamic_act_scaling: bool = True) -> Tensor:
    if not default_net().plugin_config.rmsnorm_quantization_plugin:
        raise TypeError("Fp8 Rowwise Rms Norm is only supported with plugin")
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'RmsnormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        output_type = trt.PluginField("out_type_id",
                                      np.array([int(trt.fp8)], np.int32),
                                      trt.PluginFieldType.INT32)
        quant_mode = trt.PluginField(
            "quant_mode",
            np.array([int(QuantMode.from_description(use_fp8_rowwise=True))],
                     np.int32), trt.PluginFieldType.INT32)
        clamp_enabled = trt.PluginField(
            "clamp_enabled", np.array([clamp_val is not None], np.int32),
            trt.PluginFieldType.INT32)

        eps = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                              trt.PluginFieldType.FLOAT32)

        dyn_act_scaling = trt.PluginField(
            "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),
            trt.PluginFieldType.INT32)
        sum_per_token_pf = trt.PluginField("sum_per_token",
                                           np.array([int(False)], np.int32),
                                           trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.rmsnorm_quantization_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection([
            eps, dyn_act_scaling, sum_per_token_pf, clamp_enabled, quant_mode,
            pf_type, output_type
        ])

        rmsnorm_plug = plg_creator.create_plugin("rmsnorm_quantized", pfc)
        normalized_shape = [normalized_shape] if isinstance(
            normalized_shape, int) else normalized_shape
        if weight is None:
            weight = constant(
                np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
        if bias is None:
            bias = constant(
                np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
        if scale is None:
            scale = constant(np.ones((1, ), dtype=str_dtype_to_np(p_dtype)))

        # RMS Norm Plugin only supports float32 scale
        scale = cast(scale, "float32")
        plug_inputs = [
            input.trt_tensor, weight.trt_tensor, bias.trt_tensor,
            scale.trt_tensor
        ]
        if clamp_val:
            plug_inputs += [clamp_val.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, rmsnorm_plug)
        if not default_net().strongly_typed:
            layer.get_output(0).set_dynamic_range(-448, 448)
        _add_plugin_info(layer, plg_creator, "rmsnorm_quantized", pfc)
        if not dynamic_act_scaling:
            return _create_tensor(layer.get_output(0), layer)

        return _create_tensor(layer.get_output(0),
                              layer), _create_tensor(layer.get_output(1), layer)


def fused_layernorm(
        input: Tensor,
        normalized_shape: Union[int, Tuple[int]],
        residual: Optional[Tensor] = None,
        weight: Optional[Tensor] = None,
        # beta: Optional[Tensor] = None,
        # bias: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
        eps: float = 1e-05,
        p_dtype: str = 'float16',
        need_fp32_output: bool = False) -> Tensor:
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'FusedLayernorm', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None
    eps = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                          trt.PluginFieldType.FLOAT32)
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)
    need_fp32_output_value = need_fp32_output
    need_fp32_output = trt.PluginField(
        "need_fp32_output", np.array([int(need_fp32_output_value)], np.int32),
        trt.PluginFieldType.INT32)
    need_quantize_value = scale is not None
    need_quantize = trt.PluginField(
        "need_quantize", np.array([int(need_quantize_value)], np.int32),
        trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection(
        [eps, need_fp32_output, need_quantize, pf_type])
    fused_layernorm_plug = plg_creator.create_plugin("fused_layernorm", pfc)
    normalized_shape = [normalized_shape] if isinstance(
        normalized_shape, int) else normalized_shape
    if weight is None:
        weight = constant(
            np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
    # if beta is None:
    #     beta = constant(
    #         np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
    # if bias is None:
    #     bias = constant(
    #         np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
    if need_quantize_value:
        plug_inputs = [
            input.trt_tensor, residual.trt_tensor, weight.trt_tensor,
            scale.trt_tensor
        ]
    else:
        plug_inputs = [
            input.trt_tensor,
            residual.trt_tensor,
            weight.trt_tensor,
        ]
    layer = default_trtnet().add_plugin_v2(plug_inputs, fused_layernorm_plug)
    _add_plugin_info(layer, plg_creator, "fused_layernorm", pfc)
    if not need_quantize_value:
        return _create_tensor(layer.get_output(0),
                              layer), _create_tensor(layer.get_output(1), layer)
    return _create_tensor(layer.get_output(0), layer), _create_tensor(
        layer.get_output(1), layer), _create_tensor(layer.get_output(2), layer)


def quantize(input: Tensor,
             scale_factor: Tensor,
             dtype: str,
             axis: int = -1) -> Tensor:
    layer = default_trtnet().add_quantize(input.trt_tensor,
                                          scale_factor.trt_tensor,
                                          str_dtype_to_trt(dtype))
    layer.axis = axis

    output = _create_tensor(layer.get_output(0), layer)

    return output


def dequantize(input: Tensor,
               scale_factor: Tensor,
               axis: int = -1,
               output_type: Union[str, trt.DataType] = 'float16') -> Tensor:

    if isinstance(output_type, str):
        output_type = str_dtype_to_trt(output_type)

    layer = default_trtnet().add_dequantize(input.trt_tensor,
                                            scale_factor.trt_tensor,
                                            output_type)
    layer.axis = axis

    if not default_net().strongly_typed:
        layer.precision = input.dtype

    output = _create_tensor(layer.get_output(0), layer)

    return output


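# Note on the quantize()/dequantize() pair above: they wrap TensorRT's native
# quantize/dequantize layers, so (following the usual TRT Q/DQ semantics)
# quantize computes roughly q = clamp(round(x / scale)) in the target dtype and
# dequantize computes x_hat = q * scale in `output_type`; `axis` selects the
# dimension that `scale_factor` broadcasts over for per-channel scaling.
# Illustrative round trip (sketch only, assuming `inp` and `w_scale` are
# Tensors already registered in the current network):
#
#   q = quantize(inp, w_scale, dtype='int8', axis=0)
#   x_hat = dequantize(q, w_scale, axis=0, output_type='float16')

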
def quantize_per_token(
        x: Tensor,
        clamp_val: Optional[Tensor] = None,
        scale_dtype='float32',
        sum_per_token: bool = False,
        sum_dtype='float32',
) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
    if not default_net().plugin_config.quantize_per_token_plugin:
        x = cast(x, 'float32')
        xmax = x.abs().max(-1, keepdim=True)
        scales = xmax / 127.0
        out = x * 127.0 / xmax
        out = round(out)
        out = clip(out, -128, 127)
        quantized = cast(out, 'int8')
        if not sum_per_token:
            return quantized, scales
        sums = sum(x, -1, keepdim=True)
        if sum_dtype is not None and str_dtype_to_trt(sum_dtype) != sums.dtype:
            sums = cast(sums, sum_dtype)
        return quantized, scales, sums

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    output_type = trt.PluginField("type_id", np.array([int(trt.int8)],
                                                      np.int32),
                                  trt.PluginFieldType.INT32)
    quant_mode = trt.PluginField(
        "quant_mode",
        np.array([int(QuantMode.use_smooth_quant(per_token=True))], np.int32),
        trt.PluginFieldType.INT32)
    clamp_enabled = trt.PluginField("clamp_enabled",
                                    np.array([clamp_val is not None], np.int8),
                                    trt.PluginFieldType.INT8)

    sum_per_token_pf = trt.PluginField("sum_per_token",
                                       np.array([int(sum_per_token)], np.int32),
                                       trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection(
        [output_type, quant_mode, clamp_enabled, sum_per_token_pf])
    quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin", pfc)

    plug_inputs = [x.trt_tensor]
    if clamp_val:
        plug_inputs += [clamp_val.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-127, 127)
    _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)

    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)

    # TODO: The plugin should be able to directly output float16 scales to avoid a cast
    if scale_dtype is not None and str_dtype_to_trt(
            scale_dtype) != scales.dtype:
        scales = cast(scales, scale_dtype)
    if not sum_per_token:
        return quantized, scales

    sums = _create_tensor(layer.get_output(2), layer)
    # TODO: The plugin should be able to directly output float16 sums to avoid a cast
    if sum_dtype is not None and str_dtype_to_trt(sum_dtype) != sums.dtype:
        sums = cast(sums, sum_dtype)

    return quantized, scales, sums


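# Per-token INT8 quantization reference (illustrative sketch only, not
# executed; it follows the non-plugin fallback of quantize_per_token above):
#
#   import numpy as np
#   x = np.random.randn(4, 8).astype(np.float32)          # [tokens, hidden]
#   xmax = np.abs(x).max(axis=-1, keepdims=True)
#   scales = xmax / 127.0                                  # one scale per token
#   q = np.clip(np.round(x * 127.0 / xmax), -128, 127).astype(np.int8)
#   x_hat = q.astype(np.float32) * scales                  # approximate reconstruction

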
def quantize_fp8_per_token(x: Tensor,
                           clamp_val: Optional[Tensor] = None) -> Tuple[Tensor]:
    if not default_net().plugin_config.quantize_per_token_plugin:
        x = cast(x, 'float32')
        xmax = x.abs().max(-1, keepdim=True)
        scale = xmax / 448.0
        out = x * 448.0 / xmax
        out = round(out)
        out = clip(out, -448, 448)
        quantized_out = cast(out, 'fp8')
        return quantized_out, scale
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        output_type = trt.PluginField("type_id",
                                      np.array([int(trt.fp8)], np.int32),
                                      trt.PluginFieldType.INT32)
        quant_mode = trt.PluginField(
            "quant_mode",
            np.array([int(QuantMode.from_description(use_fp8_rowwise=True))],
                     np.int32), trt.PluginFieldType.INT32)
        clamp_enabled = trt.PluginField(
            "clamp_enabled", np.array([clamp_val is not None], np.int8),
            trt.PluginFieldType.INT8)
        sum_per_token_pf = trt.PluginField("sum_per_token",
                                           np.array([int(False)], np.int32),
                                           trt.PluginFieldType.INT32)
        pfc = trt.PluginFieldCollection(
            [output_type, quant_mode, clamp_enabled, sum_per_token_pf])
        quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin",
                                                  pfc)

        plug_inputs = [x.trt_tensor]
        if clamp_val:
            plug_inputs += [clamp_val.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
        if not default_net().strongly_typed:
            layer.get_output(0).set_dynamic_range(-448, 448)
        _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)

        quantized = _create_tensor(layer.get_output(0), layer)
        scales = _create_tensor(layer.get_output(1), layer)

        return quantized, scales


def quantize_tensor(x, scale):
    if not default_net().plugin_config.quantize_tensor_plugin:
        if scale.dtype == str_dtype_to_trt('float32'):
            x = cast(x, 'float32')
        scaled = x * scale
        rounded = round(scaled)
        clipped = clip(rounded, -128, 127)
        quantized = cast(clipped, 'int8')
    else:
        # The plugin only supports a float32 scale.
        scale = cast(scale, 'float32')

        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'QuantizeTensor', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        pfc = trt.PluginFieldCollection([])
        quantize_plug = plg_creator.create_plugin("quantize_tensor_plugin", pfc)

        plug_inputs = [x.trt_tensor, scale.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
        if not default_net().strongly_typed:
            layer.get_output(0).set_dynamic_range(-127, 127)
        _add_plugin_info(layer, plg_creator, "quantize_tensor_plugin", pfc)

        quantized = _create_tensor(layer.get_output(0), layer)
    return quantized


def symmetric_quantize_last_axis_of_batched_matrix(weight, quant_mode):
    amax = weight.abs().max(dim=0)[0].to(weight.dtype)
    if quant_mode == torch.int8:
        scale = amax / 128.
        qweight = torch.clamp((weight / scale).round(), -128, 127).char()
        qweight = qweight.T.reshape(weight.shape)
    else:
        scale = amax / 8.
        qweight = torch.clamp((weight / scale).round(), -8, 7).char()
        qweight[qweight < 0] += 16
        qweight = qweight.T.view(torch.uint8)
        qweight = (qweight[:, 1::2] * 16 + qweight[:, ::2]).view(torch.int8)
        qweight = qweight.reshape(weight.shape[0], weight.shape[1] // 2)
    return qweight, scale


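# Illustrative usage of the Python reference above (sketch only; the production
# loading path typically calls the fused
# torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix op instead):
#
#   w = torch.randn(128, 256, dtype=torch.float16)
#   q8, s8 = symmetric_quantize_last_axis_of_batched_matrix(w, torch.int8)
#   # q8: int8 with the same number of elements as w; s8: one scale per column
#   q4, s4 = symmetric_quantize_last_axis_of_batched_matrix(w, torch.quint4x2)
#   # q4: two 4-bit values packed per int8 byte, so its last dimension is halved

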
def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
                                      quant_mode: torch.dtype,
                                      act_dtype: torch.dtype,
                                      sm_: int = -1) -> torch.Tensor:
    sm_ = sm_ if sm_ > 0 else get_sm_version()
    if len(tensor.shape) == 2:
        tensor = tensor.unsqueeze(0)
    elif sm_ >= 90:
        sm_ = 80
    if sm_ == 120:
        sm_ = 80

    permutation_map = {
        "16_8": [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15],
        "16_4": [
            0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12,
            13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31
        ],
        "8_4": [
            0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10,
            11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31
        ]
    }

    # permute_B_rows_for_mixed_gemm
    BITS_PER_ELT_A = 8 if act_dtype == torch.float8_e4m3fn else 16
    BITS_PER_ELT_B = 4 if quant_mode == torch.quint4x2 else 8
    MMA_SHAPE_N = 8
    B_ROWS_PER_MMA = 8 * 16 // BITS_PER_ELT_B

    num_experts = tensor.shape[0]
    num_rows = tensor.shape[1]
    num_cols = tensor.shape[2]

    assert (sm_ >= 75)
    assert (num_rows % B_ROWS_PER_MMA == 0)
    assert (num_cols % MMA_SHAPE_N == 0)

    row_idx_list = [
        (row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
        permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][row_idx %
                                                              B_ROWS_PER_MMA]
        for row_idx in range(num_rows)
    ]
    tensor = tensor[:, row_idx_list, :]

    # subbyte_transpose
    original_shape = tensor.shape
    if BITS_PER_ELT_B == 4:
        tensor = tensor.view(torch.uint8)
        high_tensor = (tensor >> 4).permute(0, 2, 1).unsqueeze(2)
        low_tensor = ((tensor << 4) >> 4).permute(0, 2, 1).unsqueeze(2)
        new_tensor = torch.cat([low_tensor, high_tensor],
                               dim=2).reshape(tensor.shape[0], -1,
                                              tensor.shape[1])
        new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
        tensor = new_tensor.view(torch.int8).reshape(original_shape)
    else:
        tensor = tensor.permute(0, 2, 1).reshape(original_shape)

    # interleave_column_major_tensor
    interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
    if interleave > 1 and sm_ < 90:
        rows_per_tile = 128 * 8 // BITS_PER_ELT_A
        elts_in_int32 = 32 // BITS_PER_ELT_B

        assert (num_rows % elts_in_int32 == 0)
        assert (num_rows % rows_per_tile == 0)

        tensor = tensor.reshape(num_experts, -1, interleave,
                                num_rows // rows_per_tile,
                                rows_per_tile * 4 // elts_in_int32)
        tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)

    # add_bias_and_interleave_quantized_tensor_inplace
    if BITS_PER_ELT_B == 8:
        tensor += -256 * (tensor > 127).byte() + 128
        tensor = tensor.reshape(-1, 4)[:, [0, 2, 1, 3]].reshape(tensor.shape)
    elif BITS_PER_ELT_B == 4:
        tensor = tensor.view(torch.uint8)
        high_tensor = (tensor >> 4).unsqueeze(-1)
        low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
        new_tensor = torch.cat([low_tensor, high_tensor],
                               dim=-1).reshape(tensor.shape[0], tensor.shape[1],
                                               -1)
        new_tensor = new_tensor.reshape(
            -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
        new_tensor += -16 * (new_tensor > 7).byte() + 8
        new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
        tensor = new_tensor.view(torch.int8)
    else:
        raise NotImplementedError

    return tensor.squeeze(0).contiguous()


def validate_group_size(layer):
    # TODO: Remove this function and its usage after W4A8-AWQ with group_size = 64 is implemented.
    W4A8_AWQ = 8
    if layer.quant_algo & W4A8_AWQ and layer.group_size == 64:
        raise NotImplementedError(
            "W4A8_AWQ with group_size = 64 is not implemented yet!")


def unpack_int32_into_int8(w_packed, autoawq_reorder=False):
    # Unpack inputs packed in int32/float32 into uint4 and store them in int8 format
    w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
    w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
                             w_packed_int4x2.shape[1] * 2,
                             dtype=torch.int8)
    w_unpacked[:, ::2] = w_packed_int4x2 % 16
    w_unpacked[:, 1::2] = w_packed_int4x2 // 16
    if autoawq_reorder:
        w_unpacked = w_unpacked.view(-1, 8)[:, [0, 4, 1, 5, 2, 6, 3, 7]].view(
            w_unpacked.shape)
    return w_unpacked.contiguous()


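# Nibble-order example for unpack_int32_into_int8 (illustrative sketch only):
# each packed byte stores two unsigned 4-bit values, low nibble first.
#
#   packed = torch.tensor([[0xBA]], dtype=torch.uint8)
#   unpack_int32_into_int8(packed)   # -> [[0x0A, 0x0B]] stored as int8
#   # with autoawq_reorder=True, every group of 8 unpacked values is
#   # additionally permuted by [0, 4, 1, 5, 2, 6, 3, 7].

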
def change_qkv_leading_dim(w, num_heads):
    if w.dim() == 1:
        w = w.reshape(num_heads, 3, -1)
        w = w.transpose(0, 1).reshape(-1)
    else:
        shape = w.shape
        head_dim = shape[1] // (3 * num_heads)
        w = w.reshape(-1, num_heads, 3, head_dim)
        w = w.transpose(1, 2).reshape(shape[0], -1)
    return w


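# Layout example for change_qkv_leading_dim (illustrative sketch only): it
# turns a head-interleaved fused QKV layout [q0, k0, v0, q1, k1, v1, ...] into
# the QKV-contiguous layout [q0, q1, ..., k0, k1, ..., v0, v1, ...] along the
# leading output dimension, e.g. for a 1-D bias with 2 heads of size 1:
#
#   b = torch.tensor([1., 10., 100., 2., 20., 200.])   # [q0, k0, v0, q1, k1, v1]
#   change_qkv_leading_dim(b, num_heads=2)             # -> [1., 2., 10., 20., 100., 200.]

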
def pad_like(w, target_shape, value=0):
    if w.shape != target_shape:
        pad_dim = []
        for dim in range(len(target_shape)):
            current_dim = -1 - dim
            pad_dim.append(0)
            pad_dim.append(
                max(0, target_shape[current_dim] - w.shape[current_dim]))
        res = F.pad(w, pad_dim, value=value)
        return res
    else:
        return w


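# Padding example for pad_like (illustrative sketch only): every dimension is
# padded at its end up to `target_shape` with `value`, using
# torch.nn.functional.pad under the hood.
#
#   w = torch.ones(3, 5)
#   pad_like(w, (4, 8)).shape    # -> torch.Size([4, 8]); the new cells are 0
#   pad_like(w, (3, 5)) is w     # -> True, already the requested shape

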
def postprocess_weight_only(tllm_key, weights, quant_mode, layer):
    if weights.dim() > 2:
        v = weights.transpose(-1, -2)
    else:
        v = weights.t()

    tp_dim = 1 if isinstance(layer, ColumnLinear) else 0
    if "weight" in tllm_key:
        if layer.is_padded:
            split_size = layer.out_features if tp_dim == 1 else layer.in_features
            v = torch.split(v, split_size, tp_dim)[layer.tp_rank]
            v = pad_like(v, (layer.in_features, layer.out_features))
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v.contiguous(), quant_mode)
        return {
            tllm_key: processed_torch_weights,
            tllm_key.replace("weight", "per_channel_scale"):
            torch_weight_scales,
        }
    else:
        if layer.is_padded and tp_dim == 1:
            weights = torch.split(weights, layer.out_features,
                                  tp_dim)[layer.tp_rank]
            weights = pad_like(weights, (layer.out_features, ))
        return {tllm_key: weights}  # Bias


def postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype, layer,
                                      **kwargs):
    using_head_as_leading_dim = kwargs.get("using_head_as_leading_dim", False)
    config = kwargs.get("config", None)
    use_autoawq = kwargs.get("use_autoawq", None)
    num_heads = config.num_attention_heads
    USE_GPTQ = layer.prequant_scaling_factor is None and use_autoawq is None
    USE_HF_AWQ = layer.prequant_scaling_factor is None and use_autoawq is not None
    USE_MODELOPT_AWQ = layer.prequant_scaling_factor is not None
    USE_INT8_WEIGHT = layer.quant_algo & 16

    tp_dim = 1 if isinstance(layer, ColumnLinear) else 0
    is_qkv = layer.is_qkv if hasattr(layer, "is_qkv") else False

    if using_head_as_leading_dim:
        assert config.num_attention_heads == config.num_key_value_heads, "using_head_as_leading_dim require head_size to be multiple of 3."
    if tllm_key.endswith("weights_scaling_factor"):
        # TODO: Remove reshaping after modelopt optimizes scale shape
        if is_qkv:
            for idx, w in enumerate(weights):
                scales = w.to(torch_dtype)
                scales = scales.reshape(-1,
                                        layer.weights_scaling_factor.shape[0]).T
                scales = scales.chunk(layer.tp_size, 1)[layer.tp_rank]
                weights[idx] = scales
            weights = torch.cat(weights, dim=1)
        else:
            scales = weights.to(torch_dtype)
            scales_shape = [
                layer.weights_scaling_factor.shape[1],
                layer.weights_scaling_factor.shape[0]
            ]
            scales_shape[1 - tp_dim] *= layer.tp_size
            scales = scales.reshape(scales_shape).T
            weights = scales.chunk(layer.tp_size, tp_dim)[layer.tp_rank]
    if is_qkv and isinstance(weights, list) and len(weights) >= 3:
        if USE_MODELOPT_AWQ:
            if tllm_key.endswith("prequant_scaling_factor"):
                weights = weights[0]
            else:
                weights = torch.cat(weights, dim=0)
        elif len(weights) > 3:
            weights = [
                torch.cat(weights[i::len(weights) // 3], dim=1)
                for i in range(len(weights) // 3)
            ]

    if tllm_key.endswith("bias"):
        if is_qkv and isinstance(weights, list):
            weights = torch.cat(weights)
        if layer.is_padded:
            weights = pad_like(weights, layer.bias.shape)
        if using_head_as_leading_dim:
            weights = change_qkv_leading_dim(weights, num_heads)
        results = {tllm_key: weights.to(torch_dtype)}
    elif tllm_key.endswith("weight"):
        if not USE_INT8_WEIGHT:
            # 4 bit quantization
            if USE_GPTQ:
                qweight = unpack_int32_into_int8(weights[0].T).T - 8
            elif USE_HF_AWQ:
                qweight = unpack_int32_into_int8(weights[0], True) - 8
            else:
                qweight = unpack_int32_into_int8(weights.T)
            qweight -= (qweight >> 4) << 4
            qweight = qweight.view(torch.uint8)
        elif USE_INT8_WEIGHT and USE_GPTQ:
            # 8 bit quantization (only consider INT8 GPTQ here)
            qweight = (
                weights[0].T.contiguous().view(torch.uint8).T.contiguous() -
                128).to(torch.int8)
        else:
            raise NotImplementedError(
                "Unsupported quantization mode for weight.")

        if using_head_as_leading_dim:
            qweight = change_qkv_leading_dim(qweight, num_heads)
        if layer.is_padded:
            qweight = torch.split(qweight, layer.out_features,
                                  tp_dim)[layer.tp_rank]
            qweight = pad_like(qweight, (layer.in_features, layer.out_features))
        # pack int8 tensor to packed int4
        if not USE_INT8_WEIGHT:
            qweight = (qweight[:, 1::2] * 16 + qweight[:, ::2]).view(torch.int8)
        weight_type = torch.int8 if USE_INT8_WEIGHT else torch.quint4x2
        qweight = preprocess_weights_for_mixed_gemm(
            qweight, weight_type, torch.float16).view(torch_dtype)
        results = {tllm_key: qweight}

        # scales and zeros for GPTQ and HF-AWQ
        if USE_GPTQ or USE_HF_AWQ:
            scales = weights[1].to(torch_dtype)
            if USE_INT8_WEIGHT:
                qzeros = weights[2].view(torch.uint8)
            else:
                qzeros = unpack_int32_into_int8(weights[2], USE_HF_AWQ)
            if using_head_as_leading_dim:
                scales = change_qkv_leading_dim(scales, num_heads)
                qzeros = change_qkv_leading_dim(qzeros, num_heads)
            if layer.is_padded:
                scales = torch.split(scales,
                                     layer.weights_scaling_factor.shape[tp_dim],
                                     tp_dim)[layer.tp_rank]
                scales = pad_like(scales, layer.weights_scaling_factor.shape, 1)
                qzeros = torch.split(qzeros,
                                     layer.weights_scaling_factor.shape[tp_dim],
                                     tp_dim)[layer.tp_rank]
                qzeros = pad_like(qzeros, layer.zero.shape, 7)
            if USE_INT8_WEIGHT:
                zeros_x_scales = (-qzeros + 128 - 1 * USE_GPTQ) * scales
            else:
                zeros_x_scales = (-qzeros + 8 - 1 * USE_GPTQ) * scales
            zeros_x_scales = zeros_x_scales.to(torch_dtype)
            results.update({
                tllm_key.replace("weight", "weights_scaling_factor"):
                scales,
                tllm_key.replace("weight", "zero"):
                zeros_x_scales,
            })
    elif tllm_key.endswith("weights_scaling_factor"):
        # TODO: Remove reshaping after modelopt optimizes scale shape
        if layer.is_padded:
            raise NotImplementedError(
                "Auto-padding is not Implemented for ModelOpt HF-AWQ.")
        results = {tllm_key: weights}
    elif tllm_key.endswith("prequant_scaling_factor"):
        prequant_scale = weights.to(torch_dtype).reshape(1, -1)
        if layer.is_padded and tp_dim == 1:
            prequant_scale = torch.split(prequant_scale,
                                         layer.prequant_scaling_factor.shape[1],
                                         1)[layer.tp_rank]
            prequant_scale = pad_like(prequant_scale,
                                      layer.prequant_scaling_factor.shape, 0)
        results = {tllm_key: prequant_scale}

    return results


def postprocess_fp8_rowwise(tllm_key, weights, **kwargs):
    if tllm_key.endswith("per_channel_scale"):
        return {}

    config = kwargs.get("config", None)
    weights, scales = weights[0::2], weights[1::2]

    if scales[0] is not None:
        assert all(w.dtype == torch.float8_e4m3fn for w in weights)
        weights = torch.cat(weights, dim=0)
        scales = torch.cat([s.to(torch.float32).flatten() for s in scales])
        return {
            tllm_key: weights,
            tllm_key.replace("weight", "per_channel_scale"): scales
        }
    else:
        x = torch.cat(weights, dim=0).to(torch.float32)
        clamp_val = config.quantization.clamp_val
        if clamp_val is not None:
            # activation range bound.
            x = x.clamp(clamp_val[0], clamp_val[1])
        xmax = x.abs().max(-1, keepdim=True).values
        # minimum scaling factor.
        torch_weight_scales = (xmax / 448.0).clamp(min=1.0 / (448.0 * 512.0))
        out = x / torch_weight_scales
        torch_weight_scales = torch_weight_scales.reshape(-1)
        out = torch.clamp(out, -448, 448)
        processed_torch_weights = out.to(torch.float8_e4m3fn)
        return {
            tllm_key: processed_torch_weights,
            tllm_key.replace("weight", "per_channel_scale"): torch_weight_scales
        }


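# Reference for the FP8-rowwise scale computation above (illustrative sketch
# only, not executed; it mirrors the unquantized branch of
# postprocess_fp8_rowwise, where 448 is the FP8-E4M3 maximum):
#
#   x = torch.randn(16, 32, dtype=torch.float32)
#   s = (x.abs().max(-1, keepdim=True).values / 448.0).clamp(min=1.0 / (448.0 * 512.0))
#   w_fp8 = (x / s).clamp(-448, 448).to(torch.float8_e4m3fn)
#   # the per-channel scales are stored flattened alongside the quantized weight

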
def fp4_gemm(input: Tensor,
             input_sf: Tensor,
             weight: Tensor,
             weight_sf: Tensor,
             global_sf: Tensor,
             output_dtype: str | trt.DataType,
             scaling_vector_size: int = 16):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, input_dim] or [num_tokens, input_dim] for remove_input_padding, should be fp4
        input_sf : Tensor (On GPU)
            The input scaling factor tensor. Its shape is [batch_size, seq_len, input_dim / scaling_vector_size] or [num_tokens, input_dim / scaling_vector_size] for remove_input_padding, should be int32 (4 packed)
        weight : Tensor (On GPU)
            The weight tensor. Its shape is [output_dim, input_dim], should be fp4
        weight_sf : Tensor (On GPU)
            The weight scaling factor tensor. Its shape is [output_dim, input_dim / scaling_vector_size], should be fp8
        global_sf : Tensor (On GPU)
            The global scaling factor tensor. Its shape is [1,], should be float32, used as alpha of Gemm.
        output_dtype: str
            output data type
        scaling_vector_size: int
            scaling vector block size
    '''
    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    fp4_gemm_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Fp4Gemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert fp4_gemm_plg_creator is not None
    sv_vec_size = trt.PluginField("sv_vec_size",
                                  np.array(scaling_vector_size, dtype=np.int32),
                                  trt.PluginFieldType.INT32)
    output_dtype = trt.PluginField("output_type_id",
                                   np.array([int(output_dtype)], np.int32),
                                   trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection([sv_vec_size, output_dtype])
    fp4_gemm_plug = fp4_gemm_plg_creator.create_plugin("fp4_gemm", pfc)
    plug_inputs = [input, input_sf, weight, weight_sf, global_sf]
    plug_inputs = [i.trt_tensor for i in plug_inputs]
    layer = default_trtnet().add_plugin_v2(plug_inputs, fp4_gemm_plug)
    _add_plugin_info(layer, fp4_gemm_plg_creator, "fp4_gemm", pfc)
    output = _create_tensor(layer.get_output(0), layer)
    return output


def quantize_to_fp4_tensor(input: Tensor, sf_scale: Tensor):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, input_dim] or [num_tokens, input_dim] for remove_input_padding, should be fp16
        sf_scale : Tensor (On GPU)
            The global per-tensor scaling factor. Its shape is [1,], should be float32.
            It is used to scale SF from the input range to the fp8 range (448.f / (MaxVal of input / 6.f)).
        output : Tensor (On GPU)
            The output tensor. Its shape is [batch_size, seq_len, input_dim] or [num_tokens, input_dim] for remove_input_padding, should be FP4
        output_sf : Tensor (On GPU)
            The output scaling factor tensor. Its shape is [batch_size, seq_len, input_dim / scaling_vector_size] or [num_tokens, input_dim / scaling_vector_size] for remove_input_padding, should be FP8
    '''
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizeToFP4', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    pfc = trt.PluginFieldCollection([])
    quantize_plug = plg_creator.create_plugin("quantize_to_fp4_plugin", pfc)

    plug_inputs = [input.trt_tensor, sf_scale.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    _add_plugin_info(layer, plg_creator, "quantize_to_fp4_plugin", pfc)

    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)
    return quantized, scales


def dynamic_quantize(
        x: Tensor,
        double_scale: Tensor,
        axis: int = -1,
        block_size: int = 16,
        data_qtype: trt.DataType = trt.fp4,
        scale_qtype: trt.DataType = trt.fp8) -> Tuple[Tensor, Tensor]:
    '''
    Parameters:
        x : Tensor (On GPU)
            The input tensor.
        double_scale : Tensor (On GPU)
            The global per-tensor scaling factor. It should contain only 1 element.
        axis : int
            The axis to quantize. Default is -1 (the last axis).
        block_size : int
            The block size for quantization. Default is 16.
        data_qtype : trt.DataType
            The data type for quantized data. Default is FP4.
        scale_qtype : trt.DataType
            The data type for block scale. Default is FP8.
    Returns:
        A tuple of two tensors: quantized tensor and block scale tensor.
    '''
    if axis < 0:
        axis = len(x.shape) + axis
    dynq = default_trtnet().add_dynamic_quantize(x.trt_tensor, axis, block_size,
                                                 data_qtype, scale_qtype)
    dynq.set_input(1, double_scale.trt_tensor)
    quantized = _create_tensor(dynq.get_output(0), dynq)
    scale = _create_tensor(dynq.get_output(1), dynq)
    return quantized, scale


def block_double_dequantize(x: Tensor,
                            scale: Tensor,
                            double_scale: Tensor,
                            dtype: trt.DataType | str = 'float16') -> Tensor:
    '''
    Parameters:
        x : Tensor (On GPU)
            The input tensor.
        scale : Tensor (On GPU)
            The block scale tensor.
        double_scale : Tensor (On GPU)
            The global per-tensor scaling factor. It should contain only 1 element.
        dtype : trt.DataType | str
            The data type for dequantized data. Default is 'float16'.
    Returns:
        The dequantized tensor.
    '''
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)
    dequantize_scale_layer = default_trtnet().add_dequantize(
        scale.trt_tensor, double_scale.trt_tensor, dtype)
    scale = _create_tensor(dequantize_scale_layer.get_output(0),
                           dequantize_scale_layer)

    dequantize_data_layer = default_trtnet().add_dequantize(
        x.trt_tensor, scale.trt_tensor, dtype)
    dequantize_data = _create_tensor(dequantize_data_layer.get_output(0),
                                     dequantize_data_layer)
    return dequantize_data
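

# Illustrative pairing of dynamic_quantize and block_double_dequantize (sketch
# only, assuming `x` and `global_scale` are Tensors already registered in the
# current network): the first call produces FP4 data plus FP8 block scales, and
# the second undoes both scaling levels to recover a float16 approximation.
#
#   x_fp4, block_scales = dynamic_quantize(x, global_scale, block_size=16)
#   x_hat = block_double_dequantize(x_fp4, block_scales, global_scale,
#                                   dtype='float16')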