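"""Helpers for swapping a model's layers with their quantized implementations.

``quantize`` is the entry point: it inspects ``QuantConfig.quant_mode`` and
dispatches to FP8, SmoothQuant, or weight-only quantization, then optionally
adds KV-cache quantization scales.
"""
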
from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
                      LayerNorm, RmsNorm, RowLinear)
from ..models.modeling_utils import QuantConfig
from ..parameter import Parameter
from .layers import (FP8Linear, FP8RowLinear, Int8SmoothQuantLinear,
                     Int8SmoothQuantRowLinear, SmoothQuantAttention,
                     SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
                     SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
                     WeightOnlyGroupwiseQuantRowLinear,
                     WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
                     WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo


def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
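    """Swap layers of ``model`` for the quantized classes given in ``quant_map``.

    ``quant_map`` maps original layer classes to their quantized counterparts.
    Modules named in ``quant_config.exclude_modules`` (or the default exclusion
    list) are skipped. ``preprocess_init_params``, if provided, can adjust the
    constructor arguments before each quantized layer is built.
    """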
    exclude_modules = quant_config.exclude_modules or [
        'lm_head',
        'router',
        'vocab_embedding',
        'position_embedding',
        'block_embedding',
    ]

    for name, module, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]
        if module_name not in exclude_modules:
            quant_cls = None
            for cls in quant_map:
                if isinstance(module, cls):
                    quant_cls = quant_map[cls]
                    break

            if quant_cls is None:
                continue

            init_params = get_init_params(module, quant_cls)
            if "bias" in init_params:
                init_params["bias"] = init_params["bias"] is not None
            if isinstance(module, ColumnLinear):
                init_params[
                    "out_features"] = module.out_features * module.tp_size
            elif isinstance(module, RowLinear):
                init_params["in_features"] = module.in_features * module.tp_size
            if preprocess_init_params is not None:
                preprocess_init_params(init_params, name, module)
            quant_layer = quant_cls(**init_params)
            setattr(parent, module_name, quant_layer)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


def weight_only_quantize(model, quant_config: QuantConfig):
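    """Apply weight-only quantization to linear and embedding layers."""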
    assert quant_config.quant_mode.is_weight_only()

    quant_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            module_name = name.rsplit('.', 1)[-1]
            init_params["transb"] = module_name == "lm_head"

    quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def weight_only_groupwise_quantize(model, quant_config: QuantConfig):
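    """Apply groupwise weight-only quantization to linear layers.

    Group size, pre-quant scale, zero-point, and W4A8-AWQ settings are taken
    from ``quant_config``.
    """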
    assert quant_config.quant_mode.is_weight_only()

    quant_map = {
        ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
        RowLinear: WeightOnlyGroupwiseQuantRowLinear,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["group_size"] = quant_config.group_size
        init_params["pre_quant_scale"] = quant_config.pre_quant_scale
        init_params["zero"] = quant_config.has_zero_point
        init_params[
            "use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ

    quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
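    """Apply SmoothQuant using out-of-the-box INT8 linear layers (no plugins)."""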
    quant_map = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }

    quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def smooth_quantize_plugin(model, quant_mode):
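    """Apply SmoothQuant using plugin-backed attention, MLP, and norm layers."""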
    quant_map = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        setattr(parent, layer_name, quant_layer)

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_config: QuantConfig):
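    """Dispatch SmoothQuant to the plugin or out-of-the-box path by quant algo."""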
    assert quant_config.quant_mode.has_act_and_weight_quant()
    if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
        return smooth_quantize_plugin(model, quant_config.quant_mode)
    else:
        return smooth_quantize_ootb(model, quant_config)


def fp8_quantize(model, quant_config: QuantConfig):
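    """Replace linear layers with FP8 quantize/dequantize variants."""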
    assert quant_config.quant_mode.has_fp8_qdq()

    quant_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
    }

    quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def kv_cache_quantize(model, quant_config: QuantConfig):
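    """Add a KV-cache scaling factor parameter to every attention layer."""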
    assert quant_config.quant_mode.has_kv_cache_quant()
    for name, module in model.named_modules():
        if isinstance(module, (Attention, SmoothQuantAttention)):
            module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
    # Return the model so callers can write `model = kv_cache_quantize(...)`,
    # as `quantize` below does.
    return model


def quantize(model, quant_config: QuantConfig):
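    """Quantize ``model`` in place according to ``quant_config`` and return it."""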
    quant_mode = quant_config.quant_mode

    if quant_mode.has_fp8_qdq():
        model = fp8_quantize(model, quant_config)
    elif quant_mode.has_act_and_weight_quant():
        model = smooth_quantize(model, quant_config)
    elif quant_mode.is_weight_only():
        if quant_mode.has_per_group_scaling():
            model = weight_only_groupwise_quantize(model, quant_config)
        else:
            model = weight_only_quantize(model, quant_config)

    if quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model, quant_config)

    return model