mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
3316 lines
135 KiB
Python
3316 lines
135 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import math
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import tensorrt as trt
|
|
import torch
|
|
|
|
from .._common import default_net, precision
|
|
from .._utils import is_same_dtype, str_dtype_to_torch, trt_dtype_to_torch
|
|
from ..functional import (ACT2FN, AllReduceFusionOp, AllReduceParams,
|
|
AttentionMaskType, PositionEmbeddingType,
|
|
RotaryScalingType, Tensor, allgather, allreduce, cast,
|
|
concat, constant, div, embedding, gemm_allreduce,
|
|
generate_alibi_slopes, gpt_attention, matmul, mul,
|
|
shape, slice, softmax, split, where)
|
|
from ..layers import MropeParams, SpecDecodingParams
|
|
from ..layers.embedding import Embedding
|
|
from ..layers.linear import Linear, RowLinear
|
|
from ..module import Module
|
|
from ..parameter import Parameter
|
|
|
|
# isort: off
|
|
from .functional import (
|
|
block_double_dequantize, dequantize, dynamic_quantize, fp4_gemm,
|
|
quantize_to_fp4_tensor, fp8_rowwise_gemm, fp8_rowwise_rms_norm,
|
|
postprocess_fp8_rowwise, postprocess_weight_only,
|
|
postprocess_weight_only_groupwise, quantize, quantize_fp8_per_token,
|
|
quantize_per_token, quantize_tensor, validate_group_size, smooth_quant_gemm,
|
|
smooth_quant_layer_norm, smooth_quant_rms_norm,
|
|
weight_only_groupwise_quant_matmul, weight_only_quant_matmul,
|
|
qserve_gemm_per_group, qserve_gemm_per_channel)
|
|
# isort: on
|
|
from .mode import GroupwiseQuantAlgo, QuantMode
|
|
|
|
|
|
class Quantize(Module):
|
|
"""
|
|
Quantize Layer
|
|
For per-tensor mode, the scaling factor is a scalar.
|
|
For per-channel mode, the scaling factor is a vector.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
output_dtype: str = 'int8',
|
|
scaling_factor_dtype: str = 'float32',
|
|
in_channels: int = -1,
|
|
axis=-1,
|
|
) -> None:
|
|
super().__init__()
|
|
self.scaling_factor = Parameter(shape=(in_channels, ) if axis != -1 else
|
|
(),
|
|
dtype=scaling_factor_dtype)
|
|
self.output_dtype = output_dtype
|
|
self.axis = axis
|
|
|
|
def forward(self, x):
|
|
return quantize(x, self.scaling_factor.value, self.output_dtype,
|
|
self.axis)
|
|
|
|
|
|
class QuantizePerToken(Module):
|
|
"""
|
|
Quantize Per Token and compute dynamic scales for SmoothQuant
|
|
"""
|
|
|
|
def forward(self, x):
|
|
return quantize_per_token(x)
|
|
|
|
|
|
class Dequantize(Module):
|
|
"""
|
|
Dequantize Layer.
|
|
"""
|
|
|
|
def __init__(self, axis: int = -1) -> None:
|
|
super().__init__()
|
|
self.scaling_factor = Parameter(shape=(), dtype='float32')
|
|
self.axis = axis
|
|
|
|
def forward(self, input):
|
|
return dequantize(input, self.scaling_factor.value, self.axis)
|
|
|
|
|
|
class SmoothQuantLinear(Linear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
gather_output=True,
|
|
quant_mode=QuantMode(0),
|
|
prefer_managed_weight=True):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
|
|
if not quant_mode.has_act_and_weight_quant():
|
|
raise ValueError(
|
|
"SmoothQuant Linear has to have act+weight quantization mode set"
|
|
)
|
|
|
|
weights_dtype = dtype
|
|
if quant_mode.has_act_and_weight_quant():
|
|
weights_dtype = "int8"
|
|
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype=weights_dtype,
|
|
prefer_managed=self.prefer_managed_weight)
|
|
|
|
if quant_mode.has_act_and_weight_quant():
|
|
scale_shape = (1, self.out_features
|
|
) if quant_mode.has_per_channel_scaling() else (1, 1)
|
|
self.per_channel_scale = Parameter(shape=scale_shape,
|
|
dtype="float32")
|
|
|
|
if quant_mode.has_act_static_scaling():
|
|
self.act_scale = Parameter(shape=(1, 1), dtype="float32")
|
|
|
|
self.quant_mode = quant_mode
|
|
|
|
def forward(self, x, lora_runtime_params=None):
|
|
assert lora_runtime_params is None, "lora is not supported on SmoothQuantLinear now"
|
|
if self.quant_mode.has_act_static_scaling():
|
|
per_token_scale = self.act_scale.value
|
|
else:
|
|
# If we are in SmoothQuant with dynamic activation scaling,
|
|
# input x has to be a tuple of int8 tensor and fp32 scaling factors
|
|
x, per_token_scale = x
|
|
x = smooth_quant_gemm(x, self.weight.value, per_token_scale,
|
|
self.per_channel_scale.value,
|
|
self.quant_mode.has_per_token_dynamic_scaling(),
|
|
self.quant_mode.has_per_channel_scaling(),
|
|
self.dtype)
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
|
|
# [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
|
|
x = allgather(x, self.tp_group, gather_dim=1)
|
|
|
|
return x
|
|
|
|
|
|
SmoothQuantColumnLinear = SmoothQuantLinear
|
|
|
|
|
|
class SmoothQuantRowLinear(RowLinear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
prefer_managed_weight=True,
|
|
):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
if not quant_mode.has_act_and_weight_quant():
|
|
raise ValueError(
|
|
"SmoothQuant Linear has to have act+weight quantization mode set"
|
|
)
|
|
weights_dtype = dtype
|
|
if quant_mode.has_act_and_weight_quant():
|
|
weights_dtype = "int8"
|
|
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype=weights_dtype,
|
|
prefer_managed=self.prefer_managed_weight)
|
|
self.smoother = Parameter(shape=(1, self.in_features), dtype="float32")
|
|
if quant_mode.has_act_and_weight_quant():
|
|
scale_shape = (1, self.out_features
|
|
) if quant_mode.has_per_channel_scaling() else (1, 1)
|
|
self.per_channel_scale = Parameter(shape=scale_shape,
|
|
dtype="float32")
|
|
|
|
if quant_mode.has_act_static_scaling():
|
|
self.act_scale = Parameter(shape=(1, 1), dtype="float32")
|
|
|
|
self.quant_mode = quant_mode
|
|
|
|
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
|
|
assert lora_runtime_params is None, "lora is not supported on SmoothQuantRowLinear now"
|
|
if self.quant_mode.has_act_static_scaling():
|
|
per_token_scale = self.act_scale.value
|
|
else:
|
|
x, per_token_scale = x
|
|
x = smooth_quant_gemm(x, self.weight.value, per_token_scale,
|
|
self.per_channel_scale.value,
|
|
self.quant_mode.has_per_token_dynamic_scaling(),
|
|
self.quant_mode.has_per_channel_scaling(),
|
|
self.dtype)
|
|
|
|
if self.tp_size > 1 and self.tp_group is not None:
|
|
need_bias = self.bias is not None
|
|
fuse_bias_into_all_reduce = need_bias and (
|
|
all_reduce_params
|
|
is not None) and (all_reduce_params.fusion_op
|
|
== AllReduceFusionOp.RESIDUAL_RMS_NORM)
|
|
if fuse_bias_into_all_reduce:
|
|
all_reduce_params.bias = self.bias.value
|
|
x = allreduce(x, self.tp_group, all_reduce_params=all_reduce_params)
|
|
if need_bias and not fuse_bias_into_all_reduce:
|
|
x = x + self.bias.value
|
|
return x
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
return x
|
|
|
|
|
|
class SmoothQuantLayerNorm(Module):
|
|
|
|
def __init__(
|
|
self,
|
|
normalized_shape,
|
|
eps=1e-05,
|
|
elementwise_affine=True,
|
|
bias=True,
|
|
dtype=None,
|
|
quant_mode=QuantMode(0),
|
|
):
|
|
super().__init__()
|
|
if isinstance(normalized_shape, int):
|
|
normalized_shape = (normalized_shape, )
|
|
if not quant_mode.has_act_and_weight_quant():
|
|
raise ValueError(
|
|
"SmoothQuant layer norm has to have some quantization mode set")
|
|
self.normalized_shape = tuple(normalized_shape)
|
|
self.elementwise_affine = elementwise_affine
|
|
if self.elementwise_affine:
|
|
self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
|
|
if bias:
|
|
self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('bias', None)
|
|
else:
|
|
self.register_parameter('weight', None)
|
|
self.register_parameter('bias', None)
|
|
|
|
self.eps = eps
|
|
self.dtype = dtype
|
|
self.quant_mode = quant_mode
|
|
|
|
if self.quant_mode.has_act_and_weight_quant():
|
|
self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)
|
|
else:
|
|
self.register_parameter('scale_to_int', None)
|
|
|
|
def forward(self, x):
|
|
weight = None if self.weight is None else self.weight.value
|
|
bias = None if self.bias is None else self.bias.value
|
|
scale = None if self.scale_to_int is None else self.scale_to_int.value
|
|
return smooth_quant_layer_norm(
|
|
x,
|
|
self.normalized_shape,
|
|
weight,
|
|
bias,
|
|
scale,
|
|
self.eps,
|
|
dynamic_act_scaling=self.quant_mode.has_per_token_dynamic_scaling())
|
|
|
|
|
|
class SmoothQuantRmsNorm(Module):
|
|
|
|
def __init__(
|
|
self,
|
|
normalized_shape,
|
|
eps=1e-06,
|
|
elementwise_affine=True,
|
|
dtype=None,
|
|
quant_mode=QuantMode(0),
|
|
bias=False,
|
|
clamp_val=None,
|
|
):
|
|
super().__init__()
|
|
if isinstance(normalized_shape, int):
|
|
normalized_shape = (normalized_shape, )
|
|
if not quant_mode.has_act_and_weight_quant():
|
|
raise ValueError(
|
|
"SmoothQuant Rms norm has to have some quantization mode set")
|
|
self.normalized_shape = tuple(normalized_shape)
|
|
self.elementwise_affine = elementwise_affine
|
|
if self.elementwise_affine:
|
|
self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('weight', None)
|
|
|
|
if bias:
|
|
self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('bias', None)
|
|
if clamp_val:
|
|
if not (isinstance(clamp_val, list) and len(clamp_val) == 2):
|
|
raise ValueError(f'unsupported clamp_val {clamp_val}')
|
|
self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
|
|
dtype='float32',
|
|
is_buffer=True)
|
|
else:
|
|
self.register_parameter('clamp_val', None)
|
|
|
|
self.eps = eps
|
|
self.dtype = dtype
|
|
self.quant_mode = quant_mode
|
|
|
|
if self.quant_mode.has_act_and_weight_quant():
|
|
self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)
|
|
else:
|
|
self.register_parameter('scale_to_int', None)
|
|
|
|
def forward(self, x):
|
|
weight = None if self.weight is None else self.weight.value
|
|
bias = None if self.bias is None else self.bias.value
|
|
scale = None if self.scale_to_int is None else self.scale_to_int.value
|
|
clamp_val = None if self.clamp_val is None else self.clamp_val.value
|
|
return smooth_quant_rms_norm(
|
|
x,
|
|
self.normalized_shape,
|
|
weight,
|
|
bias,
|
|
scale,
|
|
clamp_val,
|
|
self.eps,
|
|
dynamic_act_scaling=self.quant_mode.has_per_token_dynamic_scaling())
|
|
|
|
|
|
class QServeW4A8Linear(Linear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
gather_output=True,
|
|
quant_mode=QuantMode(0)):
|
|
assert dtype == "float16" # Currently the kernel only supports float16 output
|
|
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output)
|
|
|
|
self.quant_mode = quant_mode
|
|
assert self.quant_mode.is_qserve_w4a8()
|
|
|
|
# Only support g128 now.
|
|
if self.quant_mode.has_per_group_scaling():
|
|
self.group_size = 128
|
|
else:
|
|
self.group_size = -1
|
|
|
|
self.weight = Parameter(shape=(self.out_features,
|
|
self.in_features // 2),
|
|
dtype="int8")
|
|
|
|
self.s1_scales = Parameter(shape=(self.out_features, ), dtype="float16")
|
|
|
|
if self.group_size == -1:
|
|
self.s1_szeros = Parameter(shape=(self.out_features, ),
|
|
dtype="float16")
|
|
else:
|
|
self.s2_scales = Parameter(
|
|
shape=(self.in_features // self.group_size, self.out_features),
|
|
dtype="int8")
|
|
self.s2_szeros = Parameter(
|
|
shape=(self.in_features // self.group_size, self.out_features),
|
|
dtype="int8")
|
|
|
|
def forward(self, x):
|
|
if self.group_size == -1:
|
|
x, per_token_scale, per_token_sum = x
|
|
x = qserve_gemm_per_channel(x, per_token_scale, per_token_sum,
|
|
self.weight.value, self.s1_scales.value,
|
|
self.s1_szeros.value)
|
|
|
|
else:
|
|
x, per_token_scale = x
|
|
x = qserve_gemm_per_group(x, per_token_scale, self.weight.value,
|
|
self.s1_scales.value,
|
|
self.s2_scales.value,
|
|
self.s2_szeros.value, self.group_size)
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
|
|
# [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
|
|
x = allgather(x, self.tp_group, gather_dim=1)
|
|
return x
|
|
|
|
|
|
QServeW4A8ColumnLinear = QServeW4A8Linear
|
|
|
|
|
|
class QServeW4A8RowLinear(RowLinear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
):
|
|
assert dtype == "float16" # Currently the kernel only supports float16 output
|
|
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size)
|
|
|
|
self.quant_mode = quant_mode
|
|
assert self.quant_mode.is_qserve_w4a8()
|
|
# Only supports 128g now.
|
|
if self.quant_mode.has_per_group_scaling():
|
|
self.group_size = 128
|
|
else:
|
|
self.group_size = -1
|
|
|
|
self.weight = Parameter(shape=(self.out_features,
|
|
self.in_features // 2),
|
|
dtype="int8")
|
|
|
|
self.s1_scales = Parameter(shape=(self.out_features, ), dtype="float16")
|
|
|
|
if self.group_size == -1:
|
|
self.s1_szeros = Parameter(shape=(self.out_features, ),
|
|
dtype="float16")
|
|
else:
|
|
self.s2_scales = Parameter(
|
|
shape=(self.in_features // self.group_size, self.out_features),
|
|
dtype="int8")
|
|
self.s2_szeros = Parameter(
|
|
shape=(self.in_features // self.group_size, self.out_features),
|
|
dtype="int8")
|
|
|
|
def forward(self, x, all_reduce_params=None):
|
|
if self.group_size == -1:
|
|
x, per_token_scale, per_token_sum = x
|
|
x = qserve_gemm_per_channel(x, per_token_scale, per_token_sum,
|
|
self.weight.value, self.s1_scales.value,
|
|
self.s1_szeros.value)
|
|
else:
|
|
x, per_token_scale = x
|
|
x = qserve_gemm_per_group(x, per_token_scale, self.weight.value,
|
|
self.s1_scales.value,
|
|
self.s2_scales.value,
|
|
self.s2_szeros.value, self.group_size)
|
|
|
|
if self.tp_size > 1 and self.tp_group is not None:
|
|
need_bias = self.bias is not None
|
|
fuse_bias_into_all_reduce = need_bias and (
|
|
all_reduce_params
|
|
is not None) and (all_reduce_params.fusion_op
|
|
== AllReduceFusionOp.RESIDUAL_RMS_NORM)
|
|
if fuse_bias_into_all_reduce:
|
|
all_reduce_params.bias = self.bias.value
|
|
x = allreduce(x, self.tp_group, all_reduce_params=all_reduce_params)
|
|
if need_bias and not fuse_bias_into_all_reduce:
|
|
x = x + self.bias.value
|
|
return x
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
return x
|
|
|
|
|
|
class Fp8RowwiseRmsNorm(Module):
|
|
|
|
def __init__(
|
|
self,
|
|
normalized_shape,
|
|
eps=1e-06,
|
|
elementwise_affine=True,
|
|
dtype=None,
|
|
quant_mode=QuantMode(0),
|
|
bias=False,
|
|
clamp_val=None,
|
|
):
|
|
super().__init__()
|
|
if isinstance(normalized_shape, int):
|
|
normalized_shape = (normalized_shape, )
|
|
if not quant_mode.has_fp8_rowwise():
|
|
raise ValueError(
|
|
"Fp8 Rowwise Rms norm has to have some quantization mode set")
|
|
self.normalized_shape = tuple(normalized_shape)
|
|
self.elementwise_affine = elementwise_affine
|
|
if self.elementwise_affine:
|
|
self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('weight', None)
|
|
|
|
if bias:
|
|
self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('bias', None)
|
|
|
|
if clamp_val:
|
|
if not (isinstance(clamp_val, list) and len(clamp_val) == 2):
|
|
raise ValueError(f'unsupported clamp_val {clamp_val}')
|
|
self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
|
|
dtype='float32',
|
|
is_buffer=True)
|
|
else:
|
|
self.register_parameter('clamp_val', None)
|
|
|
|
self.eps = eps
|
|
self.dtype = dtype
|
|
self.quant_mode = quant_mode
|
|
|
|
def forward(self, x):
|
|
weight = None if self.weight is None else self.weight.value
|
|
bias = None if self.bias is None else self.bias.value
|
|
scale = None
|
|
clamp_val = None if self.clamp_val is None else self.clamp_val.value
|
|
return fp8_rowwise_rms_norm(
|
|
x,
|
|
self.normalized_shape,
|
|
weight,
|
|
bias,
|
|
scale,
|
|
clamp_val,
|
|
self.eps,
|
|
dynamic_act_scaling=self.quant_mode.has_fp8_rowwise())
|
|
|
|
|
|
class Fp8RowwiseLinear(Linear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
gather_output=True,
|
|
quant_mode=QuantMode(0)):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output)
|
|
|
|
if not quant_mode.has_fp8_rowwise():
|
|
raise ValueError(
|
|
"Fp8 Rowwise Linear has to have act+weight quantization mode set"
|
|
)
|
|
|
|
weights_dtype = dtype
|
|
if quant_mode.has_fp8_rowwise():
|
|
weights_dtype = "fp8"
|
|
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype=weights_dtype)
|
|
|
|
if quant_mode.has_fp8_rowwise():
|
|
self.per_channel_scale = Parameter(shape=(self.out_features, ),
|
|
dtype="float32")
|
|
|
|
self.quant_mode = quant_mode
|
|
self.tllm_to_externel_key_dict = {"weight": ["weight", "weight_scale"]}
|
|
|
|
def forward(self, x, lora_runtime_params=None):
|
|
assert lora_runtime_params is None, "lora is not supported on SmoothQuantLinear now"
|
|
x, per_token_scale = x
|
|
x = fp8_rowwise_gemm(x, self.weight.value, per_token_scale,
|
|
self.per_channel_scale.value,
|
|
self.quant_mode.has_per_token_dynamic_scaling(),
|
|
self.quant_mode.has_per_channel_scaling())
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
|
|
# [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
|
|
x = allgather(x, self.tp_group, gather_dim=1)
|
|
|
|
return x
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if "per_channel_scale" in tllm_key:
|
|
return {}
|
|
return postprocess_fp8_rowwise(tllm_key, weights, **kwargs)
|
|
|
|
|
|
Fp8RowwiseColumnLinear = Fp8RowwiseLinear
|
|
|
|
|
|
class Fp8RowwiseRowLinear(RowLinear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size)
|
|
if not quant_mode.has_fp8_rowwise():
|
|
raise ValueError(
|
|
"Fp8 Rowwise Linear has to have act+weight quantization mode set"
|
|
)
|
|
weights_dtype = dtype
|
|
if quant_mode.has_fp8_rowwise():
|
|
weights_dtype = "fp8"
|
|
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype=weights_dtype)
|
|
if quant_mode.has_fp8_rowwise():
|
|
self.per_channel_scale = Parameter(shape=(self.out_features, ),
|
|
dtype="float32")
|
|
|
|
self.quant_mode = quant_mode
|
|
self.tllm_to_externel_key_dict = {"weight": ["weight", "weight_scale"]}
|
|
|
|
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
|
|
assert lora_runtime_params is None, "lora is not supported on SmoothQuantRowLinear now"
|
|
x, per_token_scale = x
|
|
x = fp8_rowwise_gemm(x, self.weight.value, per_token_scale,
|
|
self.per_channel_scale.value,
|
|
self.quant_mode.has_fp8_rowwise(),
|
|
self.quant_mode.has_fp8_rowwise())
|
|
|
|
if self.tp_size > 1 and self.tp_group is not None:
|
|
need_bias = self.bias is not None
|
|
fuse_bias_into_all_reduce = need_bias and (
|
|
all_reduce_params
|
|
is not None) and (all_reduce_params.fusion_op
|
|
== AllReduceFusionOp.RESIDUAL_RMS_NORM)
|
|
if fuse_bias_into_all_reduce:
|
|
all_reduce_params.bias = self.bias.value
|
|
x = allreduce(x, self.tp_group, all_reduce_params=all_reduce_params)
|
|
if need_bias and not fuse_bias_into_all_reduce:
|
|
x = x + self.bias.value
|
|
return x
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
return x
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if "per_channel_scale" in tllm_key:
|
|
return {}
|
|
return postprocess_fp8_rowwise(tllm_key, weights, **kwargs)
|
|
|
|
|
|
class WeightOnlyQuantLinear(Linear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
tp_rank=0,
|
|
gather_output=True,
|
|
quant_mode=QuantMode.use_weight_only(),
|
|
transa=False,
|
|
transb=False,
|
|
is_qkv=False,
|
|
prefer_managed_weight=True,
|
|
):
|
|
multiple = 64 * tp_size
|
|
self.is_padded = False
|
|
if in_features % multiple > 0:
|
|
in_features = math.ceil(in_features / multiple) * multiple
|
|
self.is_padded = True
|
|
if out_features % multiple > 0:
|
|
out_features = math.ceil(out_features / multiple) * multiple
|
|
self.is_padded = True
|
|
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output,
|
|
is_qkv=is_qkv,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
if quant_mode.is_int8_weight_only():
|
|
self.weight_only_quant_mode = 1
|
|
quant_type_size_in_bits = 8
|
|
elif quant_mode.is_int4_weight_only():
|
|
self.weight_only_quant_mode = 2
|
|
quant_type_size_in_bits = 4
|
|
# we use a fake tensor with data_type = int8
|
|
self.weight = Parameter(shape=(self.in_features,
|
|
int(self.out_features *
|
|
quant_type_size_in_bits / 8)),
|
|
dtype="int8",
|
|
prefer_managed=self.prefer_managed_weight)
|
|
|
|
scale_shape = (self.out_features, )
|
|
self.per_channel_scale = Parameter(shape=scale_shape, dtype=dtype)
|
|
|
|
self.transa = transa
|
|
self.transb = transb
|
|
self.tp_rank = tp_rank
|
|
if self.is_padded:
|
|
self.tp_dim = -1
|
|
self.quant_mode = quant_mode
|
|
|
|
def forward(self, x, lora_runtime_params=None):
|
|
# ootb has not supported int4 yet.
|
|
if self.weight_only_quant_mode == 2 and not default_net(
|
|
).plugin_config.weight_only_quant_matmul_plugin:
|
|
raise TypeError(
|
|
"Int4 Weight-only Quant MatMul is only supported with plugin")
|
|
hidden_state = x
|
|
x = weight_only_quant_matmul(x, self.weight.value,
|
|
self.per_channel_scale.value,
|
|
self.weight_only_quant_mode, self.dtype,
|
|
self.transa, self.transb)
|
|
|
|
if default_net(
|
|
).plugin_config.lora_plugin and lora_runtime_params is not None:
|
|
x = x + self.lora(hidden_state,
|
|
lora_runtime_params=lora_runtime_params)
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
|
|
# [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
|
|
x = allgather(x, self.tp_group, gather_dim=1)
|
|
|
|
return x
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if "per_channel_scale" in tllm_key:
|
|
return {}
|
|
weights = super().postprocess(tllm_key, weights, **kwargs)[tllm_key]
|
|
weights = weights.to(str_dtype_to_torch(self.dtype))
|
|
return postprocess_weight_only(
|
|
tllm_key, weights,
|
|
torch.int8 if self.weight_only_quant_mode == 1 else torch.quint4x2,
|
|
self)
|
|
|
|
|
|
WeightOnlyQuantColumnLinear = WeightOnlyQuantLinear
|
|
|
|
|
|
class WeightOnlyQuantRowLinear(RowLinear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
tp_rank=0,
|
|
quant_mode=QuantMode.use_weight_only(),
|
|
prefer_managed_weight=True,
|
|
is_expert=False,
|
|
):
|
|
multiple = 64 * tp_size
|
|
self.is_padded = False
|
|
if in_features % multiple > 0:
|
|
in_features = math.ceil(in_features / multiple) * multiple
|
|
self.is_padded = True
|
|
if out_features % multiple > 0:
|
|
out_features = math.ceil(out_features / multiple) * multiple
|
|
self.is_padded = True
|
|
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
prefer_managed_weight=prefer_managed_weight,
|
|
is_expert=is_expert)
|
|
if quant_mode.is_int8_weight_only():
|
|
self.weight_only_quant_mode = 1
|
|
elif quant_mode.is_int4_weight_only():
|
|
self.weight_only_quant_mode = 2
|
|
#we use a fake tensor with data_type = int8
|
|
self.weight = Parameter(shape=(self.in_features,
|
|
int(self.out_features /
|
|
self.weight_only_quant_mode)),
|
|
dtype="int8",
|
|
prefer_managed=prefer_managed_weight)
|
|
self.per_channel_scale = Parameter(shape=(self.out_features, ),
|
|
dtype=dtype)
|
|
self.tp_rank = tp_rank
|
|
if self.is_padded:
|
|
self.tp_dim = -1
|
|
self.quant_mode = quant_mode
|
|
|
|
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
|
|
hidden_state = x
|
|
x = weight_only_quant_matmul(x, self.weight.value,
|
|
self.per_channel_scale.value,
|
|
self.weight_only_quant_mode, self.dtype)
|
|
|
|
if default_net(
|
|
).plugin_config.lora_plugin and lora_runtime_params is not None:
|
|
x = x + self.lora(hidden_state,
|
|
lora_runtime_params=lora_runtime_params)
|
|
|
|
if self.tp_size > 1 and self.tp_group is not None:
|
|
need_bias = self.bias is not None
|
|
fuse_bias_into_all_reduce = need_bias and (
|
|
all_reduce_params
|
|
is not None) and (all_reduce_params.fusion_op
|
|
== AllReduceFusionOp.RESIDUAL_RMS_NORM)
|
|
if fuse_bias_into_all_reduce:
|
|
all_reduce_params.bias = self.bias.value
|
|
if not self.is_expert:
|
|
x = allreduce(x,
|
|
self.tp_group,
|
|
all_reduce_params=all_reduce_params)
|
|
if need_bias and not fuse_bias_into_all_reduce:
|
|
bias = cast(self.bias.value, x.dtype)
|
|
x = x + bias
|
|
else:
|
|
if need_bias and not fuse_bias_into_all_reduce:
|
|
bias = cast(self.bias.value, x.dtype)
|
|
x = x + bias / self.tp_size
|
|
return x
|
|
|
|
if self.bias is not None:
|
|
x = x + self.bias.value
|
|
|
|
return x
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if "per_channel_scale" in tllm_key:
|
|
return {}
|
|
weights = weights.to(str_dtype_to_torch(self.dtype))
|
|
return postprocess_weight_only(
|
|
tllm_key, weights,
|
|
torch.int8 if self.weight_only_quant_mode == 1 else torch.quint4x2,
|
|
self)
|
|
|
|
|
|
class WeightOnlyQuantEmbedding(Embedding):
|
|
|
|
def __init__(
|
|
self,
|
|
num_embeddings: int,
|
|
embedding_dim: int,
|
|
dtype: Optional[str] = None,
|
|
tp_size: int = 1,
|
|
tp_group: Optional[list] = None,
|
|
sharding_dim: int = 0,
|
|
tp_rank: Optional[int] = None,
|
|
quant_mode=QuantMode.use_weight_only(),
|
|
):
|
|
super().__init__(
|
|
num_embeddings,
|
|
embedding_dim,
|
|
dtype, # dtype,
|
|
tp_size,
|
|
tp_group,
|
|
sharding_dim,
|
|
tp_rank)
|
|
# only support int8 wo now
|
|
# TODO support int4 wo
|
|
self.quant_mode = quant_mode
|
|
self.per_token_scale = Parameter(shape=(self.num_embeddings, ),
|
|
dtype=dtype)
|
|
|
|
if sharding_dim == 1:
|
|
self.weight = Parameter(shape=(self.num_embeddings,
|
|
self.embedding_dim // self.tp_size),
|
|
dtype="int8")
|
|
elif sharding_dim == 0:
|
|
self.weight = Parameter(shape=(math.ceil(
|
|
self.num_embeddings / self.tp_size), self.embedding_dim),
|
|
dtype="int8")
|
|
|
|
def forward(self, x):
|
|
result = embedding(x,
|
|
self.weight.value,
|
|
tp_size=self.tp_size,
|
|
tp_group=self.tp_group,
|
|
sharding_dim=self.sharding_dim,
|
|
tp_rank=self.tp_rank,
|
|
per_token_scale=self.per_token_scale.value)
|
|
|
|
return result
|
|
|
|
|
|
def unpack_int32_into_int8(w_packed):
|
|
# Unpack inputs packed in int32/float32 into uint4 and store them in int8 format
|
|
w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
|
|
w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
|
|
w_packed_int4x2.shape[1] * 2,
|
|
dtype=torch.int8)
|
|
w_unpacked[:, ::2] = w_packed_int4x2 % 16
|
|
w_unpacked[:, 1::2] = w_packed_int4x2 // 16
|
|
w_unpacked = w_unpacked.view(-1, 8)[:, [0, 4, 1, 5, 2, 6, 3, 7]].view(
|
|
w_unpacked.shape)
|
|
return w_unpacked.contiguous()
|
|
|
|
|
|
class WeightOnlyGroupwiseQuantLinear(Linear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
group_size=128,
|
|
pre_quant_scale=False,
|
|
zero=False,
|
|
bias=False,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
tp_rank=0,
|
|
gather_output=True,
|
|
use_w4a8_awq=False,
|
|
use_int8_weight=False,
|
|
is_qkv=False,
|
|
prefer_managed_weight=True,
|
|
):
|
|
multiple = max((128 if use_w4a8_awq else 64), group_size) * tp_size
|
|
self.is_padded = False
|
|
if in_features % multiple > 0:
|
|
in_features = math.ceil(in_features / multiple) * multiple
|
|
self.is_padded = True
|
|
if out_features % multiple > 0:
|
|
out_features = math.ceil(out_features / multiple) * multiple
|
|
self.is_padded = True
|
|
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output,
|
|
is_qkv=is_qkv,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
|
|
self.quant_algo = (
|
|
use_int8_weight * GroupwiseQuantAlgo.INT8_WEIGHT +
|
|
use_w4a8_awq * GroupwiseQuantAlgo.W4A8_ALPHA +
|
|
pre_quant_scale * GroupwiseQuantAlgo.PRE_QUANT_SCALE +
|
|
zero * GroupwiseQuantAlgo.ZERO + bias * GroupwiseQuantAlgo.BIAS)
|
|
self.group_size = group_size
|
|
# packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
|
|
pack_ratio = 2 if use_int8_weight else 4
|
|
self.weight = Parameter(shape=(self.in_features,
|
|
self.out_features // pack_ratio),
|
|
dtype=dtype,
|
|
prefer_managed=self.prefer_managed_weight)
|
|
|
|
scale_shape = (self.in_features // group_size, self.out_features)
|
|
self.weights_scaling_factor = Parameter(shape=scale_shape, dtype=dtype)
|
|
self.tp_rank = tp_rank
|
|
if self.is_padded:
|
|
self.tp_dim = -1
|
|
self.pre_quant_scale = pre_quant_scale
|
|
self.use_w4a8_awq = use_w4a8_awq
|
|
|
|
if pre_quant_scale:
|
|
self.prequant_scaling_factor = Parameter(shape=(1,
|
|
self.in_features),
|
|
dtype=dtype)
|
|
else:
|
|
self.register_parameter('prequant_scaling_factor', None)
|
|
|
|
if zero:
|
|
self.zero = Parameter(shape=scale_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('zero', None)
|
|
|
|
if use_w4a8_awq:
|
|
self.alpha = Parameter(shape=(1, ), dtype="float32")
|
|
else:
|
|
self.register_parameter('alpha', None)
|
|
|
|
validate_group_size(self)
|
|
if pre_quant_scale:
|
|
self.tllm_to_externel_key_dict = {
|
|
"weights_scaling_factor": "weight_scale",
|
|
"prequant_scaling_factor": "input_quantizer._pre_quant_scale",
|
|
} # AWQ
|
|
else:
|
|
self.tllm_to_externel_key_dict = {
|
|
"weight": ["qweight", "scales", "qzeros"]
|
|
} # GPTQ
|
|
|
|
def forward(self, x, lora_runtime_params=None):
|
|
pre_quant_scale = self.prequant_scaling_factor.value if self.prequant_scaling_factor else None
|
|
zero = self.zero.value if self.zero else None
|
|
bias = self.bias.value if self.bias else None
|
|
alpha = self.alpha if self.alpha else None
|
|
|
|
hidden_state = x
|
|
x = weight_only_groupwise_quant_matmul(
|
|
x, pre_quant_scale, self.weight.value,
|
|
self.weights_scaling_factor.value, zero, bias, alpha,
|
|
self.quant_algo, self.group_size, self.dtype)
|
|
|
|
if default_net(
|
|
).plugin_config.lora_plugin and lora_runtime_params is not None:
|
|
x = x + self.lora(hidden_state,
|
|
lora_runtime_params=lora_runtime_params)
|
|
|
|
if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
|
|
# [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
|
|
x = allgather(x, self.tp_group, gather_dim=1)
|
|
|
|
return x
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if tllm_key.endswith("zero") or (
|
|
self.prequant_scaling_factor is None
|
|
and tllm_key.endswith("weights_scaling_factor")):
|
|
return {}
|
|
torch_dtype = str_dtype_to_torch(self.dtype)
|
|
return postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype,
|
|
self, **kwargs)
|
|
|
|
|
|
WeightOnlyGroupwiseQuantColumnLinear = WeightOnlyGroupwiseQuantLinear
|
|
|
|
|
|
class WeightOnlyGroupwiseQuantRowLinear(RowLinear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
group_size=128,
|
|
pre_quant_scale=False,
|
|
zero=False,
|
|
bias=False,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
tp_rank=0,
|
|
use_w4a8_awq=False,
|
|
use_int8_weight=False,
|
|
prefer_managed_weight=True):
|
|
multiple = max((128 if use_w4a8_awq else 64), group_size) * tp_size
|
|
self.is_padded = False
|
|
if in_features % multiple > 0:
|
|
in_features = math.ceil(in_features / multiple) * multiple
|
|
self.is_padded = True
|
|
if out_features % multiple > 0:
|
|
out_features = math.ceil(out_features / multiple) * multiple
|
|
self.is_padded = True
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
|
|
self.quant_algo = (
|
|
use_int8_weight * GroupwiseQuantAlgo.INT8_WEIGHT +
|
|
use_w4a8_awq * GroupwiseQuantAlgo.W4A8_ALPHA +
|
|
pre_quant_scale * GroupwiseQuantAlgo.PRE_QUANT_SCALE +
|
|
zero * GroupwiseQuantAlgo.ZERO + bias * GroupwiseQuantAlgo.BIAS)
|
|
self.group_size = group_size
|
|
# packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
|
|
pack_ratio = 2 if use_int8_weight else 4
|
|
self.weight = Parameter(shape=(self.in_features,
|
|
self.out_features // pack_ratio),
|
|
dtype=dtype,
|
|
prefer_managed=self.prefer_managed_weight)
|
|
scale_shape = (self.in_features // group_size, self.out_features)
|
|
self.weights_scaling_factor = Parameter(shape=scale_shape, dtype=dtype)
|
|
self.tp_rank = tp_rank
|
|
if self.is_padded:
|
|
self.tp_dim = -1
|
|
self.pre_quant_scale = pre_quant_scale
|
|
self.use_w4a8_awq = use_w4a8_awq
|
|
|
|
if pre_quant_scale:
|
|
self.prequant_scaling_factor = Parameter(shape=(1,
|
|
self.in_features),
|
|
dtype=dtype)
|
|
else:
|
|
self.register_parameter('prequant_scaling_factor', None)
|
|
|
|
if zero:
|
|
self.zero = Parameter(shape=scale_shape, dtype=dtype)
|
|
else:
|
|
self.register_parameter('zero', None)
|
|
|
|
if use_w4a8_awq:
|
|
self.alpha = Parameter(shape=(1, ), dtype="float32")
|
|
self.activation_scaling_factor = Parameter(shape=(1, ),
|
|
dtype=trt.float32)
|
|
else:
|
|
self.register_parameter('alpha', None)
|
|
self.register_parameter('activation_scaling_factor', None)
|
|
|
|
validate_group_size(self)
|
|
if pre_quant_scale:
|
|
self.tllm_to_externel_key_dict = {
|
|
"weights_scaling_factor": "weight_scale",
|
|
"prequant_scaling_factor": "input_quantizer._pre_quant_scale",
|
|
} # AWQ
|
|
else:
|
|
self.tllm_to_externel_key_dict = {
|
|
"weight": ["qweight", "scales", "qzeros"]
|
|
} # GPTQ
|
|
|
|
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
|
|
pre_quant_scale = self.prequant_scaling_factor.value if self.prequant_scaling_factor else None
|
|
zero = self.zero.value if self.zero else None
|
|
bias = self.bias.value if self.bias else None
|
|
alpha = self.alpha if self.alpha else None
|
|
activation_scaling_factor = self.activation_scaling_factor.value if self.activation_scaling_factor else None
|
|
|
|
if self.alpha and x.dtype == trt.fp8:
|
|
x = dequantize(x, activation_scaling_factor, -1, self.dtype)
|
|
|
|
hidden_state = x
|
|
x = weight_only_groupwise_quant_matmul(
|
|
x, pre_quant_scale, self.weight.value,
|
|
self.weights_scaling_factor.value, zero, bias, alpha,
|
|
self.quant_algo, self.group_size, self.dtype)
|
|
|
|
if default_net(
|
|
).plugin_config.lora_plugin and lora_runtime_params is not None:
|
|
x = x + self.lora(hidden_state,
|
|
lora_runtime_params=lora_runtime_params)
|
|
|
|
if self.tp_size > 1 and self.tp_group is not None:
|
|
x = allreduce(x, self.tp_group, all_reduce_params=all_reduce_params)
|
|
|
|
return x
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if tllm_key.endswith("zero") or (
|
|
self.prequant_scaling_factor is None
|
|
and tllm_key.endswith("weights_scaling_factor")):
|
|
return {}
|
|
torch_dtype = str_dtype_to_torch(self.dtype)
|
|
return postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype,
|
|
self, **kwargs)
|
|
|
|
|
|
class SmoothQuantMLP(Module):
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size,
|
|
ffn_hidden_size,
|
|
hidden_act,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
):
|
|
super().__init__()
|
|
if hidden_act not in ACT2FN:
|
|
raise ValueError(
|
|
'unsupported activation function: {}'.format(hidden_act))
|
|
fc_output_size = 2 * ffn_hidden_size if hidden_act == 'swiglu' else ffn_hidden_size
|
|
self.fc = SmoothQuantColumnLinear(hidden_size,
|
|
fc_output_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False,
|
|
quant_mode=quant_mode)
|
|
|
|
self.proj = SmoothQuantRowLinear(ffn_hidden_size,
|
|
hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
quant_mode=quant_mode)
|
|
|
|
self.hidden_act = hidden_act
|
|
self.quant_mode = quant_mode
|
|
self.dtype = dtype
|
|
|
|
if self.quant_mode.has_act_static_scaling():
|
|
self.quantization_scaling_factor = Parameter(shape=(1, ),
|
|
dtype='float32')
|
|
else:
|
|
self.register_parameter('quantization_scaling_factor', None)
|
|
|
|
def forward(self, hidden_states, lora_layer_params=None):
|
|
|
|
inter = self.fc(hidden_states)
|
|
inter = ACT2FN[self.hidden_act](inter)
|
|
value = cast(self.proj.smoother.value, inter.dtype)
|
|
inter = inter / value
|
|
if self.quant_mode.has_act_and_weight_quant():
|
|
if self.quant_mode.has_act_static_scaling():
|
|
# Avoid quantization layers as it breaks int8 plugins
|
|
inter = quantize_tensor(inter,
|
|
self.quantization_scaling_factor.value)
|
|
else:
|
|
# Quantize per token outputs tuple:
|
|
# quantized tensor and scaling factors per token
|
|
inter = quantize_per_token(inter)
|
|
output = self.proj(inter)
|
|
return output
|
|
|
|
|
|
class Int8SmoothQuantRowLinear(RowLinear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
prefer_managed_weight=True):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
self.activation_scaling_factor = Parameter(shape=(1, ),
|
|
dtype=trt.float32)
|
|
self.weights_scaling_factor = Parameter(shape=(self.out_features, ),
|
|
dtype=trt.float32)
|
|
self.prequant_scaling_factor = Parameter(shape=(self.in_features, ),
|
|
dtype=dtype)
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype=trt.int8,
|
|
prefer_managed=self.prefer_managed_weight)
|
|
|
|
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
|
|
lora_hidden_state = x if lora_runtime_params is not None else None
|
|
if default_net().strongly_typed:
|
|
assert is_same_dtype(
|
|
x.dtype,
|
|
self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}"
|
|
x = mul(x, self.prequant_scaling_factor.value)
|
|
|
|
x = cast(x, self.activation_scaling_factor.value.dtype)
|
|
|
|
quantized_out = quantize(x, self.activation_scaling_factor.value,
|
|
'int8')
|
|
|
|
dequantized_out = dequantize(quantized_out,
|
|
self.activation_scaling_factor.value, -1,
|
|
self.activation_scaling_factor.value.dtype)
|
|
|
|
dequantized_out = cast(dequantized_out, self.dtype)
|
|
|
|
w_deq_out = dequantize(self.weight.value,
|
|
self.weights_scaling_factor.value, 0,
|
|
self.weights_scaling_factor.value.dtype)
|
|
|
|
w_deq_out = cast(w_deq_out, self.dtype)
|
|
return self.multiply_collect(dequantized_out,
|
|
w_deq_out,
|
|
gemm_plugin=None,
|
|
all_reduce_params=all_reduce_params,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state)
|
|
|
|
|
|
class Int8SmoothQuantLinear(Linear):
|
|
|
|
def __init__(
|
|
self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
gather_output=True,
|
|
prefer_managed_weight=True,
|
|
):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output,
|
|
prefer_managed_weight=prefer_managed_weight)
|
|
self.activation_scaling_factor = Parameter(shape=(1, ),
|
|
dtype=trt.float32)
|
|
|
|
self.weights_scaling_factor = Parameter(shape=(self.out_features, ),
|
|
dtype=trt.float32)
|
|
self.prequant_scaling_factor = Parameter(shape=(self.in_features, ),
|
|
dtype=dtype)
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype=trt.int8,
|
|
prefer_managed=self.prefer_managed_weight)
|
|
|
|
def forward(self, x, lora_runtime_params=None):
|
|
lora_hidden_state = x if lora_runtime_params is not None else None
|
|
if default_net().strongly_typed:
|
|
assert is_same_dtype(
|
|
x.dtype,
|
|
self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}"
|
|
x = mul(x, self.prequant_scaling_factor.value)
|
|
x = cast(x, self.activation_scaling_factor.value.dtype)
|
|
|
|
quantized_out = quantize(x, self.activation_scaling_factor.value,
|
|
'int8')
|
|
|
|
dequantized_out = dequantize(quantized_out,
|
|
self.activation_scaling_factor.value, -1,
|
|
self.activation_scaling_factor.value.dtype)
|
|
|
|
dequantized_out = cast(dequantized_out, self.dtype)
|
|
|
|
w_deq_out = dequantize(self.weight.value,
|
|
self.weights_scaling_factor.value, 0,
|
|
self.weights_scaling_factor.value.dtype)
|
|
w_deq_out = cast(w_deq_out, self.dtype)
|
|
|
|
return self.multiply_collect(dequantized_out,
|
|
w_deq_out,
|
|
gemm_plugin=None,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state)
|
|
|
|
|
|
class FP8Linear(Linear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
gather_output=True,
|
|
prefer_managed_weight=True,
|
|
is_qkv=False):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=gather_output,
|
|
prefer_managed_weight=prefer_managed_weight,
|
|
is_qkv=is_qkv)
|
|
self.weight = Parameter(shape=(self.out_features, self.in_features),
|
|
dtype='fp8',
|
|
prefer_managed=self.prefer_managed_weight)
|
|
self.activation_scaling_factor = Parameter(shape=(1, ),
|
|
dtype=trt.float32)
|
|
self.weights_scaling_factor = Parameter(shape=(1, ), dtype=trt.float32)
|
|
if self.is_qkv:
|
|
self.tllm_to_externel_key_dict = {
|
|
"activation_scaling_factor": "input_scale",
|
|
"weight": ["weight", "weight_scale"],
|
|
"weights_scaling_factor": "",
|
|
}
|
|
else:
|
|
self.tllm_to_externel_key_dict = {
|
|
"activation_scaling_factor": "input_scale",
|
|
"weights_scaling_factor": "weight_scale",
|
|
}
|
|
|
|
def forward(self, x, lora_runtime_params=None):
|
|
assert lora_runtime_params is None or default_net(
|
|
).plugin_config.lora_plugin == self.dtype
|
|
|
|
if default_net().strongly_typed:
|
|
assert default_net().plugin_config.user_buffer or is_same_dtype(
|
|
x.dtype,
|
|
self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}"
|
|
|
|
alpha = self.weights_scaling_factor.raw_value * self.activation_scaling_factor.raw_value
|
|
activation_scaling_factor = constant(
|
|
self.activation_scaling_factor.raw_value)
|
|
activation_scaling_factor = cast(activation_scaling_factor, self.dtype)
|
|
if x.dtype != trt.fp8:
|
|
quantized_out = quantize(x, activation_scaling_factor, 'fp8')
|
|
lora_hidden_state = x if lora_runtime_params is not None else None
|
|
else:
|
|
quantized_out = x
|
|
# TODO: add fp8 LoRA support
|
|
lora_hidden_state = dequantize(
|
|
x, activation_scaling_factor, -1,
|
|
self.dtype) if lora_runtime_params is not None else None
|
|
|
|
weights_scaling_factor = cast(self.weights_scaling_factor.value,
|
|
self.dtype)
|
|
|
|
if self.weight.value.dtype != trt.fp8:
|
|
w_quant_out = quantize(self.weight.value, weights_scaling_factor,
|
|
'fp8')
|
|
else:
|
|
w_quant_out = self.weight.value
|
|
|
|
gemm_plugin = default_net().plugin_config.gemm_plugin
|
|
|
|
low_latency_gemm_plugin = default_net(
|
|
).plugin_config.low_latency_gemm_plugin
|
|
if (low_latency_gemm_plugin == "fp8"):
|
|
return self.multiply_collect(
|
|
quantized_out,
|
|
w_quant_out,
|
|
gemm_plugin=None,
|
|
low_latency_gemm_plugin=low_latency_gemm_plugin,
|
|
use_fp8=True,
|
|
alpha=alpha,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state)
|
|
elif gemm_plugin == 'fp8':
|
|
return self.multiply_collect(
|
|
quantized_out,
|
|
w_quant_out,
|
|
gemm_plugin=gemm_plugin,
|
|
use_fp8=True,
|
|
alpha=alpha,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state)
|
|
else:
|
|
dequantized_out = dequantize(quantized_out,
|
|
activation_scaling_factor, -1,
|
|
self.dtype)
|
|
w_deq_out = dequantize(w_quant_out, weights_scaling_factor, -1,
|
|
self.dtype)
|
|
return self.multiply_collect(
|
|
dequantized_out,
|
|
w_deq_out,
|
|
gemm_plugin=None,
|
|
use_fp8=True,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state)
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if self.is_qkv:
|
|
if tllm_key.endswith("activation_scaling_factor"):
|
|
return max(weights).reshape(1, ).to(torch.float32)
|
|
elif tllm_key.endswith("weights_scaling_factor"):
|
|
return {}
|
|
elif tllm_key.endswith("weight"):
|
|
assert len(weights) == 6
|
|
weight_scaling_factors = weights[1::2]
|
|
new_amax = max(weight_scaling_factors).reshape(1, ).to(
|
|
torch.float32)
|
|
for qkv_idx in range(3):
|
|
idx = qkv_idx * 2
|
|
weights[idx] = weights[idx].view(torch.float8_e4m3fn).to(
|
|
torch.float32)
|
|
weights[idx] *= weights[idx + 1]
|
|
weights[idx] /= new_amax
|
|
weights[idx] = weights[idx].to(torch.float8_e4m3fn)
|
|
weights = torch.cat(weights[::2])
|
|
scales = new_amax
|
|
return {
|
|
tllm_key: weights,
|
|
tllm_key.replace("weight", "weights_scaling_factor"): scales
|
|
}
|
|
else:
|
|
return super().postprocess(tllm_key, weights, **kwargs)
|
|
if tllm_key.endswith("scaling_factor"):
|
|
return weights.reshape(1, ).to(torch.float32)
|
|
elif tllm_key.endswith("weight"):
|
|
return weights.view(torch.float8_e4m3fn)
|
|
elif tllm_key.endswith("bias"):
|
|
return weights.to(trt_dtype_to_torch(self.bias.dtype))
|
|
|
|
|
|
class FP8RowLinear(RowLinear):
|
|
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
prefer_managed_weight=True,
|
|
is_expert=False):
|
|
super().__init__(in_features,
|
|
out_features,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
prefer_managed_weight=prefer_managed_weight,
|
|
is_expert=is_expert)
|
|
self.weight = Parameter(
|
|
shape=(self.out_features, self.in_features),
|
|
dtype="fp8",
|
|
prefer_managed=self.prefer_managed_weight,
|
|
)
|
|
self.activation_scaling_factor = Parameter(shape=(1, ),
|
|
dtype=trt.float32)
|
|
self.weights_scaling_factor = Parameter(shape=(1, ), dtype=trt.float32)
|
|
self.tllm_to_externel_key_dict = {
|
|
"activation_scaling_factor": "input_scale",
|
|
"weights_scaling_factor": "weight_scale",
|
|
}
|
|
|
|
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
|
|
assert lora_runtime_params is None or default_net(
|
|
).plugin_config.lora_plugin == self.dtype
|
|
|
|
alpha = self.weights_scaling_factor.raw_value * self.activation_scaling_factor.raw_value
|
|
activation_scaling_factor = cast(self.activation_scaling_factor.value,
|
|
self.dtype)
|
|
if x.dtype != trt.fp8:
|
|
quantized_out = quantize(x, activation_scaling_factor, 'fp8')
|
|
lora_hidden_state = x if lora_runtime_params is not None else None
|
|
else:
|
|
quantized_out = x
|
|
# TODO: add fp8 LoRA support
|
|
lora_hidden_state = dequantize(
|
|
x, activation_scaling_factor, -1,
|
|
self.dtype) if lora_runtime_params is not None else None
|
|
|
|
weights_scaling_factor = cast(self.weights_scaling_factor.value,
|
|
self.dtype)
|
|
if self.weight.value.dtype != trt.fp8:
|
|
w_quant_out = quantize(self.weight.value, weights_scaling_factor,
|
|
'fp8')
|
|
else:
|
|
w_quant_out = self.weight.value
|
|
|
|
gemm_plugin = default_net().plugin_config.gemm_plugin
|
|
low_latency_gemm_plugin = default_net(
|
|
).plugin_config.low_latency_gemm_plugin
|
|
gemm_allreduce_plugin = default_net(
|
|
).plugin_config.gemm_allreduce_plugin
|
|
if gemm_allreduce_plugin:
|
|
ret = self.multiply_collect(quantized_out,
|
|
w_quant_out,
|
|
gemm_plugin=None,
|
|
use_fp8=True,
|
|
alpha=alpha,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state,
|
|
all_reduce_params=all_reduce_params)
|
|
elif (low_latency_gemm_plugin == "fp8"):
|
|
ret = self.multiply_collect(
|
|
quantized_out,
|
|
w_quant_out,
|
|
gemm_plugin=None,
|
|
low_latency_gemm_plugin=low_latency_gemm_plugin,
|
|
use_fp8=True,
|
|
alpha=alpha,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state,
|
|
all_reduce_params=all_reduce_params)
|
|
elif gemm_plugin == 'fp8':
|
|
ret = self.multiply_collect(quantized_out,
|
|
w_quant_out,
|
|
gemm_plugin=gemm_plugin,
|
|
use_fp8=True,
|
|
alpha=alpha,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state,
|
|
all_reduce_params=all_reduce_params)
|
|
else:
|
|
dequantized_out = dequantize(quantized_out,
|
|
activation_scaling_factor, -1,
|
|
self.dtype)
|
|
w_deq_out = dequantize(w_quant_out, weights_scaling_factor, -1,
|
|
self.dtype)
|
|
ret = self.multiply_collect(dequantized_out,
|
|
w_deq_out,
|
|
gemm_plugin=None,
|
|
use_fp8=True,
|
|
lora_runtime_params=lora_runtime_params,
|
|
lora_hidden_state=lora_hidden_state,
|
|
all_reduce_params=all_reduce_params)
|
|
return ret
|
|
|
|
def postprocess(self, tllm_key, weights, **kwargs):
|
|
if tllm_key.endswith("scaling_factor"):
|
|
return weights.reshape(1, ).to(torch.float32)
|
|
elif tllm_key.endswith("weight"):
|
|
return weights.view(torch.float8_e4m3fn)
|
|
elif tllm_key.endswith("bias"):
|
|
return weights.to(str_dtype_to_torch(self.bias.dtype))
|
|
|
|
|
|
class Fp8RowwiseMLP(Module):
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size,
|
|
ffn_hidden_size,
|
|
hidden_act,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
clamp_val=None,
|
|
):
|
|
super().__init__()
|
|
if hidden_act not in ACT2FN:
|
|
raise ValueError(
|
|
'unsupported activation function: {}'.format(hidden_act))
|
|
fc_output_size = 2 * ffn_hidden_size if hidden_act == 'swiglu' else ffn_hidden_size
|
|
self.fc = Fp8RowwiseColumnLinear(hidden_size,
|
|
fc_output_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False,
|
|
quant_mode=quant_mode)
|
|
|
|
self.proj = Fp8RowwiseRowLinear(ffn_hidden_size,
|
|
hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
quant_mode=quant_mode)
|
|
|
|
self.hidden_size = hidden_size
|
|
self.ffn_hidden_size = ffn_hidden_size
|
|
self.hidden_act = hidden_act
|
|
self.bias = bias
|
|
self.dtype = dtype
|
|
self.tp_group = tp_group
|
|
self.tp_size = tp_size
|
|
self.quant_mode = quant_mode
|
|
|
|
if clamp_val:
|
|
if not (isinstance(clamp_val, list) and len(clamp_val) == 2):
|
|
raise ValueError(f'unsupported clamp_val {clamp_val}')
|
|
self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
|
|
dtype='float32',
|
|
is_buffer=True)
|
|
else:
|
|
self.register_parameter('clamp_val', None)
|
|
|
|
def forward(self, hidden_states, lora_layer_params=None):
|
|
assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
|
|
inter = self.fc(hidden_states)
|
|
inter = ACT2FN[self.hidden_act](inter)
|
|
if self.quant_mode.has_fp8_rowwise():
|
|
# Quantize per token outputs tuple:
|
|
# quantized tensor and scaling factors per token
|
|
clamp_val = None if self.clamp_val is None else self.clamp_val.value
|
|
inter = quantize_fp8_per_token(inter, clamp_val)
|
|
output = self.proj(inter)
|
|
return output
|
|
|
|
|
|
class Fp8RowwiseGatedMLP(Fp8RowwiseMLP):
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size,
|
|
ffn_hidden_size,
|
|
hidden_act,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
clamp_val=None,
|
|
):
|
|
super().__init__(hidden_size,
|
|
ffn_hidden_size,
|
|
hidden_act,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
quant_mode=quant_mode,
|
|
clamp_val=clamp_val)
|
|
if hidden_act not in ACT2FN:
|
|
raise ValueError(
|
|
'unsupported activation function: {}'.format(hidden_act))
|
|
self.gate = Fp8RowwiseColumnLinear(hidden_size,
|
|
ffn_hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False,
|
|
quant_mode=quant_mode)
|
|
|
|
def forward(self, hidden_states, lora_layer_params=None):
|
|
assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
|
|
inter = self.fc(hidden_states)
|
|
inter = ACT2FN[self.hidden_act](inter)
|
|
gate = self.gate(hidden_states)
|
|
inter_x_gate = inter * gate
|
|
if self.quant_mode.has_fp8_rowwise():
|
|
# Quantize per token outputs tuple:
|
|
# quantized tensor and scaling factors per token
|
|
clamp_val = None if self.clamp_val is None else self.clamp_val.value
|
|
inter_x_gate = quantize_fp8_per_token(inter_x_gate, clamp_val)
|
|
output = self.proj(inter_x_gate)
|
|
return output
|
|
|
|
|
|
class Fp8RowwiseFusedGatedMLP(Module):
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size,
|
|
ffn_hidden_size,
|
|
hidden_act,
|
|
bias=True,
|
|
dtype=None,
|
|
tp_group=None,
|
|
tp_size=1,
|
|
quant_mode=QuantMode(0),
|
|
clamp_val=None,
|
|
):
|
|
super().__init__()
|
|
self.hidden_size = hidden_size
|
|
self.ffn_hidden_size = ffn_hidden_size
|
|
self.hidden_act = hidden_act
|
|
self.bias = bias
|
|
self.dtype = dtype
|
|
self.tp_group = tp_group
|
|
self.tp_size = tp_size
|
|
self.quant_mode = quant_mode
|
|
|
|
self.fused_fc = Fp8RowwiseColumnLinear(hidden_size,
|
|
ffn_hidden_size * 2,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
gather_output=False,
|
|
quant_mode=quant_mode)
|
|
|
|
self.proj = Fp8RowwiseRowLinear(ffn_hidden_size,
|
|
hidden_size,
|
|
bias=bias,
|
|
dtype=dtype,
|
|
tp_group=tp_group,
|
|
tp_size=tp_size,
|
|
quant_mode=quant_mode)
|
|
|
|
if clamp_val:
|
|
if not (isinstance(clamp_val, list) and len(clamp_val) == 2):
|
|
raise ValueError(f'unsupported clamp_val {clamp_val}')
|
|
self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
|
|
dtype='float32',
|
|
is_buffer=True)
|
|
else:
|
|
self.register_parameter('clamp_val', None)
|
|
|
|
def forward(self, hidden_states, lora_layer_params=None):
|
|
assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
|
|
inter = self.fused_fc(hidden_states)
|
|
|
|
if self.hidden_act == 'silu':
|
|
inter = ACT2FN['swiglu'](inter)
|
|
elif self.hidden_act == 'gelu':
|
|
inter = ACT2FN['geglu'](inter)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Activation {self.hidden_act} not yet implemented for {self.__class__.__name__}."
|
|
)
|
|
|
|
if self.quant_mode.has_fp8_rowwise():
|
|
# Quantize per token outputs tuple:
|
|
# quantized tensor and scaling factors per token
|
|
clamp_val = None if self.clamp_val is None else self.clamp_val.value
|
|
inter = quantize_fp8_per_token(inter, clamp_val)
|
|
output = self.proj(inter)
|
|
return output
|
|
|
|
|
|
class Fp8RowwiseAttention(Module):

    def __init__(self,
                 *,
                 local_layer_idx,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=1024,
                 num_layers=1,
                 apply_query_key_layer_scaling=False,
                 attention_head_size=None,
                 attention_mask_type=AttentionMaskType.padding,
                 bias=True,
                 dense_bias=None,
                 dtype=None,
                 position_embedding_type=PositionEmbeddingType.learned_absolute,
                 rotary_embedding_base=10000.0,
                 rotary_embedding_scaling=None,
                 rotary_embedding_percentage=1.0,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 scale_alibi_bias=False,
                 paged_kv_cache=False,
                 quant_mode=QuantMode(0),
                 clamp_val=None):
        super().__init__()
        self.local_layer_idx = local_layer_idx
        self.attention_mask_type = attention_mask_type
        self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_kv_heads = num_kv_heads
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = 0 if max_position_embeddings is None else max_position_embeddings
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = 1
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers
        # Whether to scale ALiBi bias. Mathematically, it's equivalent to
        # normalizing QK after adding bias.
        # - False, inv_sqrt_Dh * Q*K^T + alibi_bias
        # - True, inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias
        self.scale_alibi_bias = scale_alibi_bias

        self.position_embedding_type = position_embedding_type
        self.paged_kv_cache = paged_kv_cache

        self.rotary_embedding_base = rotary_embedding_base
        self.rotary_embedding_scale_type = RotaryScalingType.none
        self.rotary_embedding_scale = 1.0
        self.rotary_embedding_dim = 0

        if rotary_embedding_scaling is not None:
            rotary_scaling_type = rotary_embedding_scaling.get(
                "type", rotary_embedding_scaling.get("rope_type"))
            self.rotary_embedding_scale_type = RotaryScalingType.from_string(
                rotary_scaling_type)
            self.rotary_embedding_scale = rotary_embedding_scaling.get(
                "factor", 1.0)

        if self.position_embedding_type.is_rope():
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)
        elif self.position_embedding_type.is_alibi():
            alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1.
            alibi_slopes = generate_alibi_slopes(self.num_attention_heads *
                                                 self.tp_size,
                                                 tp_size=self.tp_size,
                                                 tp_rank=self.tp_rank,
                                                 alibi_scale=alibi_scale)
            self.register_parameter(
                'alibi_slopes',
                Parameter(alibi_slopes, dtype='float32', is_buffer=True))

        self.quant_mode = quant_mode
        self.dtype = dtype

        self.register_parameter('kv_cache_scaling_factor', None)

        self.qkv = Fp8RowwiseColumnLinear(
            hidden_size,
            tp_size * self.num_attention_heads * self.attention_head_size +
            (2 * tp_size * self.num_attention_kv_heads *
             self.attention_head_size),
            bias=bias,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
            gather_output=False,
            quant_mode=quant_mode)

        self.dense = Fp8RowwiseRowLinear(tp_size * self.num_attention_heads *
                                         self.attention_head_size,
                                         hidden_size,
                                         bias=self.dense_bias,
                                         dtype=dtype,
                                         tp_group=tp_group,
                                         tp_size=tp_size,
                                         quant_mode=quant_mode)

        if clamp_val:
            if not (isinstance(clamp_val, list) and len(clamp_val) == 2):
                raise ValueError(f'unsupported clamp_val {clamp_val}')
            self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
                                       dtype='float32',
                                       is_buffer=True)
        else:
            self.register_parameter('clamp_val', None)

        self.use_lora = False

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        spec_decoding_params=None,
        encoder_output=None,
        position_embedding=None,
        norm_before_bmm1=False,
        lora_layer_params=None,
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
        qkv = self.qkv(hidden_states)

        alibi_slopes = None
        if self.position_embedding_type == PositionEmbeddingType.alibi:
            alibi_slopes = self.alibi_slopes.value
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin or default_net(
            ).plugin_config.inflight_batching_gpt_attention_plugin:
                dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling(
                ) else hidden_states[0].dtype
                if dtype == trt.int8:
                    dtype = trt.float16
            alibi_slopes = cast(alibi_slopes, dtype)

        if spec_decoding_params is None:
            spec_decoding_params = SpecDecodingParams()

        assert default_net().plugin_config.gpt_attention_plugin

        assert attention_params.is_valid(
            default_net().plugin_config.gpt_attention_plugin,
            default_net().plugin_config.remove_input_padding, use_cache)
        if use_cache:
            assert kv_cache_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin)
        assert self.attention_mask_type == AttentionMaskType.causal, \
            'Plugin only supports masked MHA.'
        if self.kv_cache_scaling_factor is not None:
            kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
            kv_quant_orig_scale = self.kv_cache_scaling_factor.value
        else:
            kv_orig_quant_scale = None
            kv_quant_orig_scale = None
        if self.position_embedding_type.is_rope():
            rotary_inv_freq = attention_params.rotary_inv_freq
            rotary_cos_sin = attention_params.embed_positions_for_gpt_attention
        else:
            rotary_inv_freq = None
            rotary_cos_sin = None
        context, past_key_value = gpt_attention(
            qkv=qkv,
            past_key_value=kv_cache_params.get_first_past_key_value(),
            sequence_length=attention_params.sequence_length,
            host_past_key_value_lengths=kv_cache_params.
            host_past_key_value_lengths,
            host_max_attention_window_sizes=kv_cache_params.
            host_max_attention_window_sizes,
            host_sink_token_length=kv_cache_params.host_sink_token_length,
            context_lengths=attention_params.context_lengths,
            cache_indirection=kv_cache_params.cache_indirection,
            host_request_types=attention_params.host_request_types,
            layer_idx=self.local_layer_idx,
            num_heads=self.num_attention_heads,
            num_kv_heads=self.num_attention_kv_heads,
            num_kv_heads_origin=self.num_kv_heads,
            hidden_size_per_head=self.attention_head_size,
            q_scaling=self.q_scaling,
            rotary_embedding_dim=self.rotary_embedding_dim,
            rotary_embedding_base=self.rotary_embedding_base,
            rotary_embedding_scale_type=self.rotary_embedding_scale_type,
            rotary_embedding_scale=self.rotary_embedding_scale,
            rotary_embedding_max_positions=self.max_position_embeddings,
            position_embedding_type=self.position_embedding_type,
            rotary_inv_freq=rotary_inv_freq,
            rotary_cos_sin=rotary_cos_sin,
            kv_orig_quant_scale=kv_orig_quant_scale,
            kv_quant_orig_scale=kv_quant_orig_scale,
            kv_cache_quant_mode=self.quant_mode,
            max_context_length=attention_params.max_context_length,
            alibi_slopes=alibi_slopes,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            kv_cache_block_offsets=kv_cache_params.kv_cache_block_offsets,
            host_kv_cache_block_offsets=kv_cache_params.
            host_kv_cache_block_offsets,
            host_kv_cache_pool_pointers=kv_cache_params.
            host_kv_cache_pool_pointers,
            host_kv_cache_pool_mapping=kv_cache_params.
            host_kv_cache_pool_mapping,
            host_context_lengths=attention_params.host_context_lengths,
            use_cache=use_cache,
            spec_decoding_generation_lengths=spec_decoding_params.
            spec_decoding_generation_lengths,
            spec_decoding_position_offsets=spec_decoding_params.
            spec_decoding_position_offsets,
            spec_decoding_packed_mask=spec_decoding_params.
            spec_decoding_packed_mask,
            host_runtime_perf_knobs=attention_params.host_runtime_perf_knobs,
            host_context_progress=attention_params.host_context_progress)

        if self.quant_mode.has_fp8_rowwise():
            # Per-token quantization returns a tuple:
            # the quantized tensor and the per-token scaling factors.
            clamp_val = None if self.clamp_val is None else self.clamp_val.value
            context = quantize_fp8_per_token(context, clamp_val)

        context = self.dense(
            context,
            all_reduce_params=all_reduce_params,
        )

        if use_cache:
            return (context, past_key_value)

        return context

    def postprocess(self, tllm_key, weights, **kwargs):
        if tllm_key.endswith("kv_cache_scaling_factor") and weights is None:
            return {tllm_key: torch.ones(1, )}
        else:
            return {tllm_key: weights}


class FP4Linear(Linear):
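    """Column-parallel linear layer with NVFP4 weights and activations, using
    per-16-element block scaling factors plus global scales; falls back to a
    dequantize-then-matmul path when the 'nvfp4' GEMM plugin is not enabled."""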

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        gather_output=True,
        prefer_managed_weight=True,
        is_qkv=False,
    ):
        super().__init__(in_features,
                         out_features,
                         bias=bias,
                         dtype=dtype,
                         tp_group=tp_group,
                         tp_size=tp_size,
                         gather_output=gather_output,
                         prefer_managed_weight=prefer_managed_weight,
                         is_qkv=is_qkv)
        self.scaling_vector_size = 16
        assert self.in_features % self.scaling_vector_size == 0, \
            "Input features must be a multiple of 16 for FP4 GEMM"

        self.weight = Parameter(shape=(self.out_features, self.in_features),
                                dtype=trt.fp4)
        self.weights_block_scaling_factor = Parameter(
            shape=(self.out_features,
                   self.in_features // self.scaling_vector_size),
            dtype=trt.fp8)
        self.weights_block_scaling_factor_interleaved = Parameter(
            shape=(self.out_features,
                   self.in_features // self.scaling_vector_size),
            dtype=trt.fp8)
        self.weights_global_scaling_factor = Parameter(shape=(1, ),
                                                       dtype=trt.float32)
        self.activation_global_scaling_factor = Parameter(shape=(1, ),
                                                          dtype=trt.float32)
        # alpha = 1.0 / (weight_global_scale * act_global_scale)
        self.alpha = Parameter(shape=(1, ), dtype=trt.float32)
        if self.is_qkv:
            self.tllm_to_externel_key_dict = {
                "weight":
                ["weight", "weight_scale", "weight_scale_2", "input_scale"],
                "weights_block_scaling_factor":
                "",
                "weights_block_scaling_factor_interleaved":
                "",
                "weights_global_scaling_factor":
                "",
                "activation_global_scaling_factor":
                "",
                "alpha":
                "",
            }
        else:
            self.tllm_to_externel_key_dict = {
                "weight": ["weight"],
                "weights_block_scaling_factor": "weight_scale",
                "weights_block_scaling_factor_interleaved": "weight_scale",
                "weights_global_scaling_factor": "weight_scale_2",
                "activation_global_scaling_factor": "input_scale",
                "alpha": ["weight_scale_2", "input_scale"],
            }

    def forward(self, x, lora_runtime_params=None):
        assert lora_runtime_params is None, "lora is not supported on FP4Linear now"
        if isinstance(x, (tuple, list)):
            fp4_x, act_per_block_scale = x
        else:
            if default_net().plugin_config.gemm_plugin == 'nvfp4':
                fp4_x, act_per_block_scale = quantize_to_fp4_tensor(
                    x, div(1, self.activation_global_scaling_factor.value))
            else:
                fp4_x, act_per_block_scale = dynamic_quantize(
                    x, self.activation_global_scaling_factor.value)
        if default_net().plugin_config.gemm_plugin == 'nvfp4':
            x = fp4_gemm(fp4_x, act_per_block_scale, self.weight.value,
                         self.weights_block_scaling_factor_interleaved.value,
                         self.alpha.value, self.dtype)
        else:
            quant_w = self.weight.value
            scale_w = self.weights_block_scaling_factor.value
            dequant_w = block_double_dequantize(
                quant_w,
                scale_w,
                self.weights_global_scaling_factor.value,
                dtype=trt.float16)
            dequant_x = block_double_dequantize(
                fp4_x,
                act_per_block_scale,
                self.activation_global_scaling_factor.value,
                dtype=trt.float16)
            x = matmul(dequant_x, dequant_w, transb=True).cast(self.dtype)

        if self.bias is not None:
            x = x + self.bias.value

        if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
            # [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
            x = allgather(x, self.tp_group, gather_dim=1)

        return x

    def postprocess(self, tllm_key, weights, **kwargs):
        if not any([
                tllm_key.endswith(suffix)
                for suffix in self.tllm_to_externel_key_dict
        ]):
            return super().postprocess(tllm_key, weights, **kwargs)
        if self.is_qkv:
            if tllm_key.endswith("weight"):
                assert len(weights) == 12
                qkv_weight = weights[0::4]
                qkv_block_scale = weights[1::4]
                qkv_global_scale = weights[2::4]
                qkv_input_scale = weights[3::4]

                # The checkpoint guarantees that Q/K/V share the max scale,
                # so there is no need to re-quantize here.
                qkv_input_scale = max(qkv_input_scale).reshape(1, ).to(
                    torch.float32)
                qkv_global_scale = max(qkv_global_scale).reshape(1, ).to(
                    torch.float32)
                qkv_weight = torch.cat(qkv_weight)
                qkv_block_scale = torch.cat(qkv_block_scale)
                return {
                    tllm_key:
                    qkv_weight,
                    tllm_key.replace("weight", "weights_block_scaling_factor"):
                    qkv_block_scale,
                    tllm_key.replace(
                        'weight', "weights_block_scaling_factor_interleaved"):
                    torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                        qkv_block_scale.view(
                            torch.uint8).cpu().contiguous()).reshape(
                                qkv_block_scale.shape).view(
                                    torch.float8_e4m3fn),
                    tllm_key.replace("weight", "weights_global_scaling_factor"):
                    qkv_global_scale,
                    tllm_key.replace("weight", "activation_global_scaling_factor"):
                    qkv_input_scale,
                    tllm_key.replace("weight", "alpha"):
                    qkv_input_scale * qkv_global_scale,
                }
            else:
                return {}
        else:
            if tllm_key.endswith("weight"):
                return weights
            elif tllm_key.endswith("weights_block_scaling_factor"):
                return weights
            elif tllm_key.endswith("weights_block_scaling_factor_interleaved"):
                return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                    weights.view(torch.uint8).cpu().contiguous()).reshape(
                        weights.shape).view(torch.float8_e4m3fn)
            elif tllm_key.endswith("weights_global_scaling_factor"):
                return weights.float()
            elif tllm_key.endswith('activation_global_scaling_factor'):
                return weights.float()
            elif tllm_key.endswith('alpha'):
                weight_global_sf = weights[0].float()
                act_global_sf = weights[1].float()
                return act_global_sf * weight_global_sf


class FP4RowLinear(RowLinear):
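    """Row-parallel counterpart of FP4Linear. When tensor parallelism is
    enabled, the output is all-reduced, optionally fusing the bias into the
    RESIDUAL_RMS_NORM all-reduce."""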

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
    ):
        super().__init__(in_features,
                         out_features,
                         bias=bias,
                         dtype=dtype,
                         tp_group=tp_group,
                         tp_size=tp_size)
        self.scaling_vector_size = 16
        assert self.in_features % self.scaling_vector_size == 0, \
            "Input features must be a multiple of 16 for FP4 GEMM"
        self.weight = Parameter(shape=(self.out_features, self.in_features),
                                dtype=trt.fp4)
        self.weights_block_scaling_factor = Parameter(
            shape=(self.out_features,
                   self.in_features // self.scaling_vector_size),
            dtype=trt.fp8)
        self.weights_block_scaling_factor_interleaved = Parameter(
            shape=(self.out_features,
                   self.in_features // self.scaling_vector_size),
            dtype=trt.fp8)
        self.weights_global_scaling_factor = Parameter(shape=(1, ),
                                                       dtype=trt.float32)
        self.activation_global_scaling_factor = Parameter(shape=(1, ),
                                                          dtype=trt.float32)
        # alpha = 1.0 / (weight_global_scale * act_global_scale)
        self.alpha = Parameter(shape=(1, ), dtype=trt.float32)

        self.tllm_to_externel_key_dict = {
            "weight": ["weight"],
            "weights_block_scaling_factor": "weight_scale",
            "weights_block_scaling_factor_interleaved": "weight_scale",
            "weights_global_scaling_factor": "weight_scale_2",
            "activation_global_scaling_factor": "input_scale",
            "alpha": ["weight_scale_2", "input_scale"],
        }

    def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
        assert lora_runtime_params is None, "lora is not supported on FP4RowLinear now"

        if isinstance(x, (tuple, list)):
            fp4_x, act_per_block_scale = x
        else:
            if default_net().plugin_config.gemm_plugin == "nvfp4":
                fp4_x, act_per_block_scale = quantize_to_fp4_tensor(
                    x, div(1.0, self.activation_global_scaling_factor.value))
            else:
                # WAR for FP8 output attention
                if x.dtype == trt.fp8:
                    # The stored scale is an NVFP4 scale; convert it back to an FP8 scale.
                    new_scale_factor = self.activation_global_scaling_factor.raw_value
                    new_scale_factor = constant(new_scale_factor * 6)
                    x = dequantize(x, new_scale_factor, 0,
                                   new_scale_factor.dtype)
                fp4_x, act_per_block_scale = dynamic_quantize(
                    x, self.activation_global_scaling_factor.value)

        if default_net().plugin_config.gemm_allreduce_plugin:
            x = gemm_allreduce(
                a=fp4_x,
                b=self.weight.value,
                a_sf=act_per_block_scale,
                b_sf=self.weights_block_scaling_factor_interleaved.value,
                transa=False,  # row-major
                transb=True,  # col-major
                alpha=self.alpha.value,
                group=self.tp_group,  # ranks participating
                fp8_inputs_override=False)
        else:
            if default_net().plugin_config.gemm_plugin == "nvfp4":
                x = fp4_gemm(
                    fp4_x, act_per_block_scale, self.weight.value,
                    self.weights_block_scaling_factor_interleaved.value,
                    self.alpha.value, self.dtype)
            else:
                quant_w = self.weight.value
                scale_w = self.weights_block_scaling_factor.value
                dequant_x = block_double_dequantize(
                    fp4_x, act_per_block_scale,
                    self.activation_global_scaling_factor.value, trt.float16)
                dequant_w = block_double_dequantize(
                    quant_w, scale_w, self.weights_global_scaling_factor.value,
                    trt.float16)
                x = matmul(dequant_x, dequant_w, transb=True).cast(self.dtype)

        if self.tp_size > 1 and self.tp_group is not None:
            need_bias = self.bias is not None
            fuse_bias_into_all_reduce = need_bias and (
                all_reduce_params
                is not None) and (all_reduce_params.fusion_op
                                  == AllReduceFusionOp.RESIDUAL_RMS_NORM)
            if fuse_bias_into_all_reduce:
                all_reduce_params.bias = self.bias.value
            x = allreduce(x,
                          self.tp_group,
                          all_reduce_params=all_reduce_params)
            if need_bias and not fuse_bias_into_all_reduce:
                x = x + self.bias.value
            return x

        if self.bias is not None:
            x = x + self.bias.value

        return x

    def postprocess(self, tllm_key, weights, **kwargs):
        if not any([
                tllm_key.endswith(suffix)
                for suffix in self.tllm_to_externel_key_dict
        ]):
            return super().postprocess(tllm_key, weights, **kwargs)

        if tllm_key.endswith("weight"):
            return weights
        elif tllm_key.endswith("weights_block_scaling_factor"):
            return weights
        elif tllm_key.endswith("weights_block_scaling_factor_interleaved"):
            return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                weights.view(torch.uint8).cpu().contiguous()).reshape(
                    weights.shape).view(torch.float8_e4m3fn)
        elif tllm_key.endswith("weights_global_scaling_factor"):
            return weights.float()
        elif tllm_key.endswith('activation_global_scaling_factor'):
            return weights.float()
        elif tllm_key.endswith('alpha'):
            weight_global_sf = weights[0].float()
            act_global_sf = weights[1].float()
            return act_global_sf * weight_global_sf


class SmoothQuantGatedMLP(SmoothQuantMLP):
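    """SmoothQuant gated MLP: adds a gate projection on top of SmoothQuantMLP
    and re-quantizes the smoothed gate*up activation (statically or per token)
    before the down projection."""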

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        quant_mode=QuantMode(0),
    ):
        super().__init__(hidden_size,
                         ffn_hidden_size,
                         hidden_act,
                         bias=bias,
                         dtype=dtype,
                         tp_group=tp_group,
                         tp_size=tp_size,
                         quant_mode=quant_mode)
        if hidden_act not in ACT2FN:
            raise ValueError(
                'unsupported activation function: {}'.format(hidden_act))
        self.gate = SmoothQuantColumnLinear(hidden_size,
                                            ffn_hidden_size,
                                            bias=bias,
                                            dtype=dtype,
                                            tp_group=tp_group,
                                            tp_size=tp_size,
                                            gather_output=False,
                                            quant_mode=quant_mode)

        if self.quant_mode.has_act_static_scaling():
            self.quantization_scaling_factor = Parameter(shape=(1, ),
                                                         dtype='float32')
        else:
            self.register_parameter('quantization_scaling_factor', None)

    def forward(self, hidden_states, lora_layer_params=None):
        assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
        inter = self.fc(hidden_states)
        inter = ACT2FN[self.hidden_act](inter)
        gate = self.gate(hidden_states)
        inter_x_gate = inter * gate
        smoother = cast(self.proj.smoother.value, self.dtype)
        inter_x_gate = inter_x_gate / smoother
        if self.quant_mode.has_act_and_weight_quant():
            if self.quant_mode.has_act_static_scaling():
                # Avoid quantization layers as they break int8 plugins
                inter_x_gate = quantize_tensor(
                    inter_x_gate, self.quantization_scaling_factor.value)
            else:
                # Per-token quantization returns a tuple:
                # the quantized tensor and the per-token scaling factors.
                inter_x_gate = quantize_per_token(inter_x_gate)

        output = self.proj(inter_x_gate)
        return output


class SmoothQuantAttention(Module):
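    """Multi-head attention block for SmoothQuant: QKV uses per-channel weight
    scaling, and the attention output is smoothed and re-quantized (statically
    or per token) before the dense projection."""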

    def __init__(self,
                 *,
                 local_layer_idx,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=1024,
                 num_layers=1,
                 apply_query_key_layer_scaling=False,
                 attention_head_size=None,
                 attention_mask_type=AttentionMaskType.padding,
                 bias=True,
                 dense_bias=None,
                 dtype=None,
                 position_embedding_type=PositionEmbeddingType.learned_absolute,
                 rotary_embedding_base=10000.0,
                 rotary_embedding_scaling=None,
                 rotary_embedding_percentage=1.0,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 scale_alibi_bias=False,
                 paged_kv_cache=False,
                 quant_mode=QuantMode(0)):
        super().__init__()
        self.local_layer_idx = local_layer_idx
        self.attention_mask_type = attention_mask_type
        self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_kv_heads = num_kv_heads
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = 0 if max_position_embeddings is None else max_position_embeddings
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = 1
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers
        # Whether to scale ALiBi bias. Mathematically, it's equivalent to
        # normalizing QK after adding bias.
        # - False, inv_sqrt_Dh * Q*K^T + alibi_bias
        # - True, inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias
        self.scale_alibi_bias = scale_alibi_bias

        self.position_embedding_type = position_embedding_type
        self.paged_kv_cache = paged_kv_cache

        self.rotary_embedding_base = rotary_embedding_base
        self.rotary_embedding_scale_type = RotaryScalingType.none
        self.rotary_embedding_scale = 1.0
        self.rotary_embedding_dim = 0

        if rotary_embedding_scaling is not None:
            rotary_scaling_type = rotary_embedding_scaling.get(
                "type", rotary_embedding_scaling.get("rope_type"))
            self.rotary_embedding_scale_type = RotaryScalingType.from_string(
                rotary_scaling_type)
            self.rotary_embedding_scale = rotary_embedding_scaling.get(
                "factor", 1.0)

        if self.position_embedding_type.is_rope():
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)
        elif self.position_embedding_type.is_alibi():
            alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1.
            alibi_slopes = generate_alibi_slopes(self.num_attention_heads *
                                                 self.tp_size,
                                                 tp_size=self.tp_size,
                                                 tp_rank=self.tp_rank,
                                                 alibi_scale=alibi_scale)
            self.register_parameter(
                'alibi_slopes',
                Parameter(alibi_slopes, dtype='float32', is_buffer=True))

        self.quant_mode = quant_mode
        self.dtype = dtype

        if self.quant_mode.has_act_static_scaling():
            self.quantization_scaling_factor = Parameter(shape=(1, ),
                                                         dtype='float32')
        else:
            self.register_parameter('quantization_scaling_factor', None)

        qkv_quant_mode = quant_mode
        if self.quant_mode.has_act_and_weight_quant():
            # We need to hijack quant_mode for QKV because QKV always uses per channel scaling
            qkv_quant_mode = QuantMode.from_description(
                True, True, quant_mode.has_per_token_dynamic_scaling(), True)

        self.register_parameter('kv_cache_scaling_factor', None)

        self.qkv = SmoothQuantColumnLinear(
            hidden_size,
            tp_size * self.num_attention_heads * self.attention_head_size +
            (2 * tp_size * self.num_attention_kv_heads *
             self.attention_head_size),
            bias=bias,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
            gather_output=False,
            quant_mode=qkv_quant_mode)

        self.dense = SmoothQuantRowLinear(tp_size * self.num_attention_heads *
                                          self.attention_head_size,
                                          hidden_size,
                                          bias=self.dense_bias,
                                          dtype=dtype,
                                          tp_group=tp_group,
                                          tp_size=tp_size,
                                          quant_mode=quant_mode)

        self.use_lora = False

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        spec_decoding_params=None,
        mrope_params=None,
        encoder_output=None,
        position_embedding=None,
        norm_before_bmm1=False,
        lora_layer_params=None,
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
        qkv = self.qkv(hidden_states)

        alibi_slopes = None
        if self.position_embedding_type == PositionEmbeddingType.alibi:
            alibi_slopes = self.alibi_slopes.value
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin or default_net(
            ).plugin_config.inflight_batching_gpt_attention_plugin:
                dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling(
                ) else hidden_states[0].dtype
                if dtype == trt.int8:
                    dtype = trt.float16
            alibi_slopes = cast(alibi_slopes, dtype)

        if spec_decoding_params is None:
            spec_decoding_params = SpecDecodingParams()

        if mrope_params is None:
            mrope_params = MropeParams()

        if default_net().plugin_config.gpt_attention_plugin:

            assert attention_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin,
                default_net().plugin_config.remove_input_padding, use_cache)
            if use_cache:
                assert kv_cache_params.is_valid(
                    default_net().plugin_config.gpt_attention_plugin)
            assert self.attention_mask_type == AttentionMaskType.causal, \
                'Plugin only supports masked MHA.'
            if self.kv_cache_scaling_factor is not None:
                kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
                kv_quant_orig_scale = self.kv_cache_scaling_factor.value
            else:
                kv_orig_quant_scale = None
                kv_quant_orig_scale = None
            if self.position_embedding_type.is_rope():
                rotary_inv_freq = attention_params.rotary_inv_freq
                rotary_cos_sin = attention_params.embed_positions_for_gpt_attention
            else:
                rotary_inv_freq = None
                rotary_cos_sin = None
            context, past_key_value = gpt_attention(
                qkv=qkv,
                past_key_value=kv_cache_params.get_first_past_key_value(),
                sequence_length=attention_params.sequence_length,
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                host_max_attention_window_sizes=kv_cache_params.
                host_max_attention_window_sizes,
                host_sink_token_length=kv_cache_params.host_sink_token_length,
                context_lengths=attention_params.context_lengths,
                cache_indirection=kv_cache_params.cache_indirection,
                host_request_types=attention_params.host_request_types,
                layer_idx=self.local_layer_idx,
                num_heads=self.num_attention_heads,
                num_kv_heads=self.num_attention_kv_heads,
                num_kv_heads_origin=self.num_kv_heads,
                hidden_size_per_head=self.attention_head_size,
                q_scaling=self.q_scaling,
                rotary_embedding_dim=self.rotary_embedding_dim,
                rotary_embedding_base=self.rotary_embedding_base,
                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                rotary_embedding_scale=self.rotary_embedding_scale,
                rotary_embedding_max_positions=self.max_position_embeddings,
                position_embedding_type=self.position_embedding_type,
                rotary_inv_freq=rotary_inv_freq,
                rotary_cos_sin=rotary_cos_sin,
                kv_orig_quant_scale=kv_orig_quant_scale,
                kv_quant_orig_scale=kv_quant_orig_scale,
                kv_cache_quant_mode=self.quant_mode,
                max_context_length=attention_params.max_context_length,
                alibi_slopes=alibi_slopes,
                tp_size=self.tp_size,
                tp_rank=self.tp_rank,
                kv_cache_block_offsets=kv_cache_params.kv_cache_block_offsets,
                host_kv_cache_block_offsets=kv_cache_params.
                host_kv_cache_block_offsets,
                host_kv_cache_pool_pointers=kv_cache_params.
                host_kv_cache_pool_pointers,
                host_kv_cache_pool_mapping=kv_cache_params.
                host_kv_cache_pool_mapping,
                host_context_lengths=attention_params.host_context_lengths,
                use_cache=use_cache,
                spec_decoding_generation_lengths=spec_decoding_params.
                spec_decoding_generation_lengths,
                spec_decoding_position_offsets=spec_decoding_params.
                spec_decoding_position_offsets,
                spec_decoding_packed_mask=spec_decoding_params.
                spec_decoding_packed_mask,
                mrope_rotary_cos_sin=mrope_params.mrope_rotary_cos_sin,
                mrope_position_deltas=mrope_params.mrope_position_deltas,
                host_runtime_perf_knobs=attention_params.
                host_runtime_perf_knobs,
                host_context_progress=attention_params.host_context_progress,
            )
        else:
            assert self.paged_kv_cache == False

            def transpose_for_scores(x):
                new_x_shape = concat([
                    shape(x, 0),
                    shape(x, 1), self.num_attention_heads,
                    self.attention_head_size
                ])
                return x.view(new_x_shape).permute([0, 2, 1, 3])

            query, key, value = split(qkv, self.hidden_size, dim=2)
            query = transpose_for_scores(query)
            key = transpose_for_scores(key)
            value = transpose_for_scores(value)

            past_key_value = kv_cache_params.get_first_past_key_value()
            if past_key_value is not None:

                def dequantize_tensor(x, scale):
                    # Cast from int8 to dtype
                    casted_x = cast(x, self.dtype)
                    return casted_x * scale

                if self.quant_mode.has_int8_kv_cache():
                    past_key_value = dequantize_tensor(
                        past_key_value, self.kv_dequantization_scale.value)

                # past_key_value [bs, 2, num_heads, max_seq_len, head_dim]
                past_key, past_value = split(past_key_value, 1, dim=1)

                key_shape = concat([
                    shape(past_key, 0),
                    shape(past_key, 2),
                    shape(past_key, 3),
                    shape(past_key, 4)
                ])
                past_key = past_key.view(key_shape, zero_is_placeholder=False)
                past_value = past_value.view(key_shape,
                                             zero_is_placeholder=False)
                key = concat([past_key, key], dim=2)
                value = concat([past_value, value], dim=2)

            def merge_caches():
                key_inflated_shape = concat([
                    shape(key, 0), 1,
                    shape(key, 1),
                    shape(key, 2),
                    shape(key, 3)
                ])
                inflated_key = key.view(key_inflated_shape,
                                        zero_is_placeholder=False)
                inflated_value = value.view(key_inflated_shape,
                                            zero_is_placeholder=False)
                past_key_value = concat([inflated_key, inflated_value], dim=1)
                return past_key_value

            if self.attention_mask_type == AttentionMaskType.causal:
                query_length = shape(query, 2)
                key_length = shape(key, 2)
                starts = concat([0, 0, key_length - query_length, 0])
                sizes = concat([1, 1, query_length, key_length])
                buffer = constant(
                    np.expand_dims(
                        np.tril(
                            np.ones(
                                (self.max_position_embeddings,
                                 self.max_position_embeddings))).astype(bool),
                        (0, 1)))
                causal_mask = slice(buffer, starts, sizes)

            key = key.permute([0, 1, 3, 2])
            with precision("float32"):
                attention_scores = matmul(query, key)

                if self.attention_mask_type == AttentionMaskType.causal:
                    attention_scores = where(causal_mask, attention_scores,
                                             -10000.0)

                attention_scores = attention_scores / self.norm_factor
                attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value,
                             use_fp32_acc=False).permute([0, 2, 1, 3])
            context = context.view(
                concat([shape(context, 0),
                        shape(context, 1), self.hidden_size]))

            past_key_value = merge_caches()

            if use_cache and self.quant_mode.has_int8_kv_cache():
                past_key_value = quantize_tensor(
                    past_key_value, self.kv_quantization_scale.value)
        value = cast(self.dense.smoother.value, context.dtype)
        context = context / value
        if self.quant_mode.has_act_and_weight_quant():
            if self.quant_mode.has_act_static_scaling():
                # Avoid quantization layers as they break int8 plugins
                context = quantize_tensor(
                    context, self.quantization_scaling_factor.value)
            else:
                # Per-token quantization returns a tuple:
                # the quantized tensor and the per-token scaling factors.
                context = quantize_per_token(context)

        context = self.dense(
            context,
            all_reduce_params=all_reduce_params,
        )

        if use_cache:
            return (context, past_key_value)

        return context


# TODO: Duplicates SmoothQuantRmsNorm
class QServeRmsNorm(Module):
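    """RMSNorm that also emits per-token activation scales (and per-token sums
    when per-channel, non-grouped quantization is used) in the format expected
    by the QServe W4A8 GEMMs."""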

    def __init__(self,
                 normalized_shape,
                 eps=1e-06,
                 elementwise_affine=False,
                 dtype=None,
                 quant_mode=QuantMode(0),
                 bias=False,
                 clamp_val=None):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape, )
        assert quant_mode.is_qserve_w4a8()
        self.normalized_shape = tuple(normalized_shape)
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
        else:
            self.register_parameter('weight', None)

        if bias:
            self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
        else:
            self.register_parameter('bias', None)
        if clamp_val:
            if not (isinstance(clamp_val, list) and len(clamp_val) == 2):
                raise ValueError(f'unsupported clamp_val {clamp_val}')
            self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
                                       dtype='float32',
                                       is_buffer=True)
        else:
            self.register_parameter('clamp_val', None)

        self.eps = eps
        self.dtype = dtype
        self.quant_mode = quant_mode

        if self.quant_mode.has_act_and_weight_quant():
            self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)
        else:
            self.register_parameter('scale_to_int', None)

    def forward(self, x):
        weight = None if self.weight is None else self.weight.value
        bias = None if self.bias is None else self.bias.value
        scale = None if self.scale_to_int is None else self.scale_to_int.value
        clamp_val = None if self.clamp_val is None else self.clamp_val.value
        return smooth_quant_rms_norm(
            x,
            self.normalized_shape,
            weight,
            bias,
            scale,
            clamp_val,
            self.eps,
            dynamic_act_scaling=True,
            scale_dtype='float16',
            sum_per_token=not self.quant_mode.has_per_group_scaling(),
            sum_dtype='float16')


# TODO: Mostly duplicates SmoothQuantMLP.
# TODO: MLP could represent GatedMLP if hidden_act=='swiglu'.
class QServeMLP(Module):
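    """MLP for QServe W4A8 quantization; the activation output is re-quantized
    per token (with optional per-token sums) before the down projection."""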

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        quant_mode=QuantMode(0),
    ):
        super().__init__()
        if hidden_act not in ACT2FN:
            raise ValueError(
                'unsupported activation function: {}'.format(hidden_act))
        fc_output_size = 2 * ffn_hidden_size if hidden_act == 'swiglu' else ffn_hidden_size
        self.fc = QServeW4A8ColumnLinear(hidden_size,
                                         fc_output_size,
                                         bias=bias,
                                         dtype=dtype,
                                         tp_group=tp_group,
                                         tp_size=tp_size,
                                         gather_output=False,
                                         quant_mode=quant_mode)

        self.proj = QServeW4A8RowLinear(ffn_hidden_size,
                                        hidden_size,
                                        bias=bias,
                                        dtype=dtype,
                                        tp_group=tp_group,
                                        tp_size=tp_size,
                                        quant_mode=quant_mode)

        self.hidden_act = hidden_act
        self.quant_mode = quant_mode
        self.dtype = dtype

    def forward(self, hidden_states):
        inter = self.fc(hidden_states)
        inter = ACT2FN[self.hidden_act](inter)
        inter = quantize_per_token(
            inter,
            scale_dtype='float16',
            sum_per_token=not self.quant_mode.has_per_group_scaling(),
            sum_dtype='float16')
        output = self.proj(inter)
        return output


# TODO: Mostly duplicates SmoothQuantGatedMLP.
class QServeGatedMLP(QServeMLP):
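    """Gated variant of QServeMLP with a separate W4A8 gate projection."""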

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size,
        hidden_act,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        quant_mode=QuantMode(0),
    ):
        super().__init__(hidden_size,
                         ffn_hidden_size,
                         hidden_act,
                         bias=bias,
                         dtype=dtype,
                         tp_group=tp_group,
                         tp_size=tp_size,
                         quant_mode=quant_mode)
        if hidden_act not in ACT2FN:
            raise ValueError(
                'unsupported activation function: {}'.format(hidden_act))
        self.gate = QServeW4A8Linear(hidden_size,
                                     ffn_hidden_size,
                                     bias=bias,
                                     dtype=dtype,
                                     tp_group=tp_group,
                                     tp_size=tp_size,
                                     gather_output=False,
                                     quant_mode=quant_mode)

    def forward(self, hidden_states, lora_layer_params=None):
        assert lora_layer_params is None, "lora_layer_params not supported"
        inter = self.fc(hidden_states)
        inter = ACT2FN[self.hidden_act](inter)
        gate = self.gate(hidden_states)

        inter_x_gate = inter * gate
        inter_x_gate = quantize_per_token(
            inter_x_gate,
            scale_dtype='float16',
            sum_per_token=not self.quant_mode.has_per_group_scaling(),
            sum_dtype='float16')

        output = self.proj(inter_x_gate)
        return output


# TODO: Duplicates SmoothQuantAttention.
class QServeAttention(Module):
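    """Multi-head attention block for QServe W4A8 quantization; the QKV and
    dense projections require the QServe GEMM plugin, and the attention output
    is re-quantized per token before the dense projection."""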

    def __init__(self,
                 *,
                 local_layer_idx,
                 hidden_size,
                 num_attention_heads,
                 num_kv_heads=None,
                 max_position_embeddings=1024,
                 num_layers=1,
                 apply_query_key_layer_scaling=False,
                 attention_head_size=None,
                 attention_mask_type=AttentionMaskType.padding,
                 bias=True,
                 dense_bias=None,
                 dtype=None,
                 position_embedding_type=PositionEmbeddingType.learned_absolute,
                 rotary_embedding_base=10000.0,
                 rotary_embedding_scaling=None,
                 rotary_embedding_percentage=1.0,
                 tp_group=None,
                 tp_size=1,
                 tp_rank=0,
                 scale_alibi_bias=False,
                 paged_kv_cache=False,
                 quant_mode=QuantMode(0)):
        super().__init__()
        self.local_layer_idx = local_layer_idx
        self.attention_mask_type = attention_mask_type
        self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_kv_heads = num_kv_heads
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = 0 if max_position_embeddings is None else max_position_embeddings
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = 1
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers
        # Whether to scale ALiBi bias. Mathematically, it's equivalent to
        # normalizing QK after adding bias.
        # - False, inv_sqrt_Dh * Q*K^T + alibi_bias
        # - True, inv_sqrt_Dh * Q*K^T + inv_sqrt_Dh * alibi_bias
        self.scale_alibi_bias = scale_alibi_bias

        self.position_embedding_type = position_embedding_type
        self.paged_kv_cache = paged_kv_cache

        self.rotary_embedding_base = rotary_embedding_base
        self.rotary_embedding_scale_type = RotaryScalingType.none
        self.rotary_embedding_scale = 1.0
        self.rotary_embedding_dim = 0

        if rotary_embedding_scaling is not None:
            rotary_scaling_type = rotary_embedding_scaling.get(
                "type", rotary_embedding_scaling.get("rope_type"))
            self.rotary_embedding_scale_type = RotaryScalingType.from_string(
                rotary_scaling_type)
            self.rotary_embedding_scale = rotary_embedding_scaling.get(
                "factor", 1.0)

        if self.position_embedding_type.is_rope():
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)
        elif self.position_embedding_type.is_alibi():
            alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1.
            alibi_slopes = generate_alibi_slopes(self.num_attention_heads *
                                                 self.tp_size,
                                                 tp_size=self.tp_size,
                                                 tp_rank=self.tp_rank,
                                                 alibi_scale=alibi_scale)
            self.register_parameter(
                'alibi_slopes',
                Parameter(alibi_slopes, dtype='float32', is_buffer=True))

        self.quant_mode = quant_mode
        self.dtype = dtype

        # QServe does not use act static scaling
        # if self.quant_mode.has_act_static_scaling():
        #     self.quantization_scaling_factor = Parameter(shape=(1, ),
        #                                                  dtype='float32')
        # else:
        #     self.register_parameter('quantization_scaling_factor', None)

        qkv_quant_mode = quant_mode
        if self.quant_mode.has_act_and_weight_quant():
            # We need to hijack quant_mode for QKV because QKV always uses per channel scaling
            qkv_quant_mode = QuantMode.from_description(
                True, True, quant_mode.has_per_token_dynamic_scaling(), True)

        self.register_parameter('kv_cache_scaling_factor', None)

        self.qkv = QServeW4A8ColumnLinear(
            hidden_size,
            tp_size * self.num_attention_heads * self.attention_head_size +
            (2 * tp_size * self.num_attention_kv_heads *
             self.attention_head_size),
            bias=bias,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
            gather_output=False,
            quant_mode=qkv_quant_mode)

        self.dense = QServeW4A8RowLinear(tp_size * self.num_attention_heads *
                                         self.attention_head_size,
                                         hidden_size,
                                         bias=self.dense_bias,
                                         dtype=dtype,
                                         tp_group=tp_group,
                                         tp_size=tp_size,
                                         quant_mode=quant_mode)

        self.use_lora = False

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        spec_decoding_params=None,
        encoder_output=None,
        position_embedding=None,
        norm_before_bmm1=False,
        lora_layer_params=None,
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
        if default_net().plugin_config.qserve_gemm_plugin:
            qkv = self.qkv(hidden_states)
        else:
            raise ValueError("qserve_gemm_plugin is not set")

        alibi_slopes = None
        if self.position_embedding_type == PositionEmbeddingType.alibi:
            alibi_slopes = self.alibi_slopes.value
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin or default_net(
            ).plugin_config.inflight_batching_gpt_attention_plugin:
                dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling(
                ) else hidden_states[0].dtype
                if dtype == trt.int8:
                    dtype = trt.float16
            alibi_slopes = cast(alibi_slopes, dtype)

        if spec_decoding_params is None:
            spec_decoding_params = SpecDecodingParams()

        if default_net().plugin_config.gpt_attention_plugin:

            assert attention_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin,
                default_net().plugin_config.remove_input_padding, use_cache)
            if use_cache:
                assert kv_cache_params.is_valid(
                    default_net().plugin_config.gpt_attention_plugin)
            assert self.attention_mask_type == AttentionMaskType.causal, \
                'Plugin only supports masked MHA.'
            if self.kv_cache_scaling_factor is not None:
                kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
                kv_quant_orig_scale = self.kv_cache_scaling_factor.value
            else:
                kv_orig_quant_scale = None
                kv_quant_orig_scale = None
            if self.position_embedding_type.is_rope():
                rotary_inv_freq = attention_params.rotary_inv_freq
                rotary_cos_sin = attention_params.embed_positions_for_gpt_attention
            else:
                rotary_inv_freq = None
                rotary_cos_sin = None
            context, past_key_value = gpt_attention(
                qkv=qkv,
                past_key_value=kv_cache_params.get_first_past_key_value(),
                sequence_length=attention_params.sequence_length,
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                host_max_attention_window_sizes=kv_cache_params.
                host_max_attention_window_sizes,
                host_sink_token_length=kv_cache_params.host_sink_token_length,
                context_lengths=attention_params.context_lengths,
                cache_indirection=kv_cache_params.cache_indirection,
                host_request_types=attention_params.host_request_types,
                layer_idx=self.local_layer_idx,
                num_heads=self.num_attention_heads,
                num_kv_heads=self.num_attention_kv_heads,
                num_kv_heads_origin=self.num_kv_heads,
                hidden_size_per_head=self.attention_head_size,
                q_scaling=self.q_scaling,
                rotary_embedding_dim=self.rotary_embedding_dim,
                rotary_embedding_base=self.rotary_embedding_base,
                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                rotary_embedding_scale=self.rotary_embedding_scale,
                rotary_embedding_max_positions=self.max_position_embeddings,
                position_embedding_type=self.position_embedding_type,
                rotary_inv_freq=rotary_inv_freq,
                rotary_cos_sin=rotary_cos_sin,
                kv_orig_quant_scale=kv_orig_quant_scale,
                kv_quant_orig_scale=kv_quant_orig_scale,
                kv_cache_quant_mode=self.quant_mode,
                max_context_length=attention_params.max_context_length,
                alibi_slopes=alibi_slopes,
                tp_size=self.tp_size,
                tp_rank=self.tp_rank,
                kv_cache_block_offsets=kv_cache_params.kv_cache_block_offsets,
                host_kv_cache_block_offsets=kv_cache_params.
                host_kv_cache_block_offsets,
                host_kv_cache_pool_pointers=kv_cache_params.
                host_kv_cache_pool_pointers,
                host_kv_cache_pool_mapping=kv_cache_params.
                host_kv_cache_pool_mapping,
                host_context_lengths=attention_params.host_context_lengths,
                use_cache=use_cache,
                spec_decoding_generation_lengths=spec_decoding_params.
                spec_decoding_generation_lengths,
                spec_decoding_position_offsets=spec_decoding_params.
                spec_decoding_position_offsets,
                spec_decoding_packed_mask=spec_decoding_params.
                spec_decoding_packed_mask,
                host_runtime_perf_knobs=attention_params.
                host_runtime_perf_knobs,
                host_context_progress=attention_params.host_context_progress)
        else:
            assert self.paged_kv_cache == False

            def transpose_for_scores(x):
                new_x_shape = concat([
                    shape(x, 0),
                    shape(x, 1), self.num_attention_heads,
                    self.attention_head_size
                ])
                return x.view(new_x_shape).permute([0, 2, 1, 3])

            query, key, value = split(qkv, self.hidden_size, dim=2)
            query = transpose_for_scores(query)
            key = transpose_for_scores(key)
            value = transpose_for_scores(value)

            past_key_value = kv_cache_params.get_first_past_key_value()
            if past_key_value is not None:

                def dequantize_tensor(x, scale):
                    # Cast from int8 to dtype
                    casted_x = cast(x, self.dtype)
                    return casted_x * scale

                if self.quant_mode.has_int8_kv_cache():
                    past_key_value = dequantize_tensor(
                        past_key_value, self.kv_dequantization_scale.value)

                # past_key_value [bs, 2, num_heads, max_seq_len, head_dim]
                past_key, past_value = split(past_key_value, 1, dim=1)

                key_shape = concat([
                    shape(past_key, 0),
                    shape(past_key, 2),
                    shape(past_key, 3),
                    shape(past_key, 4)
                ])
                past_key = past_key.view(key_shape, zero_is_placeholder=False)
                past_value = past_value.view(key_shape,
                                             zero_is_placeholder=False)
                key = concat([past_key, key], dim=2)
                value = concat([past_value, value], dim=2)

            def merge_caches():
                key_inflated_shape = concat([
                    shape(key, 0), 1,
                    shape(key, 1),
                    shape(key, 2),
                    shape(key, 3)
                ])
                inflated_key = key.view(key_inflated_shape,
                                        zero_is_placeholder=False)
                inflated_value = value.view(key_inflated_shape,
                                            zero_is_placeholder=False)
                past_key_value = concat([inflated_key, inflated_value], dim=1)
                return past_key_value

            if self.attention_mask_type == AttentionMaskType.causal:
                query_length = shape(query, 2)
                key_length = shape(key, 2)
                starts = concat([0, 0, key_length - query_length, 0])
                sizes = concat([1, 1, query_length, key_length])
                buffer = constant(
                    np.expand_dims(
                        np.tril(
                            np.ones(
                                (self.max_position_embeddings,
                                 self.max_position_embeddings))).astype(bool),
                        (0, 1)))
                causal_mask = slice(buffer, starts, sizes)

            key = key.permute([0, 1, 3, 2])
            with precision("float32"):
                attention_scores = matmul(query, key)

                if self.attention_mask_type == AttentionMaskType.causal:
                    attention_scores = where(causal_mask, attention_scores,
                                             -10000.0)

                attention_scores = attention_scores / self.norm_factor
                attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value,
                             use_fp32_acc=False).permute([0, 2, 1, 3])
            context = context.view(
                concat([shape(context, 0),
                        shape(context, 1), self.hidden_size]))

            past_key_value = merge_caches()

            if use_cache and self.quant_mode.has_int8_kv_cache():
                past_key_value = quantize_tensor(
                    past_key_value, self.kv_quantization_scale.value)

        # Per-token quantization returns a tuple:
        # the quantized tensor and the per-token scaling factors.
        context = quantize_per_token(
            context,
            scale_dtype='float16',
            sum_per_token=not self.quant_mode.has_per_group_scaling(),
            sum_dtype='float16')

        context = self.dense(context, all_reduce_params=all_reduce_params)

        if use_cache:
            return (context, past_key_value)

        return context