# tensorrt_llm/quantization/quantize.py

import fnmatch
from typing import Union

from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
from .layers import (FP8Linear, FP8RowLinear, Fp8RowwiseGatedMLP, Fp8RowwiseMLP,
Fp8RowwiseRmsNorm, Int8SmoothQuantLinear,
Int8SmoothQuantRowLinear, SmoothQuantAttention,
SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
WeightOnlyGroupwiseQuantRowLinear,
WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode


def quantize_layers(
model,
quant_config: QuantConfig,
quant_map,
preprocess_init_params=None,
):
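    """Generic layer-swapping pass.

    Replaces every module whose class appears as a key in ``quant_map`` with
    the mapped quantized class, skipping modules that match an
    ``exclude_modules`` pattern. ``preprocess_init_params``, if given, may
    adjust the constructor kwargs for each matched module before the
    quantized layer is built.
    """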
exclude_modules = quant_config.exclude_modules or [
'*lm_head',
'*router',
'*vocab_embedding',
'*position_embedding',
'*block_embedding',
'*shared_expert_gate',
]
for name, module, parent in model.named_modules_with_parent():
module_name = name.rsplit('.', 1)[-1]
is_excluded = False
for exclude_module in exclude_modules:
if fnmatch.fnmatchcase(name, exclude_module):
is_excluded = True
                # MoE modules are quantized at initialization time, so for an
                # excluded module we re-initialize a non-quantized (FP) MoE.
if isinstance(module, MixtureOfExperts):
init_params = get_init_params(module, MixtureOfExperts)
init_params["quant_mode"] = QuantMode(0)
original_layer = MixtureOfExperts(**init_params)
if parent is not None:
setattr(parent, module_name, original_layer)
else:
model = original_layer
break
if not is_excluded:
quant_cls = None
for cls in quant_map:
if isinstance(module, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(module, quant_cls)
if "bias" in init_params:
init_params["bias"] = init_params["bias"] is not None
            # The layer constructors shard dimensions by tp_size internally,
            # so pass the full (gathered) dimensions here.
            if isinstance(module, ColumnLinear):
                init_params[
                    "out_features"] = module.out_features * module.tp_size
            elif isinstance(module, RowLinear):
                init_params["in_features"] = module.in_features * module.tp_size
if preprocess_init_params is not None:
preprocess_init_params(init_params, name, module)
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, module_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_config.quant_mode)
return model


def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
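    """Swap Linear/Embedding layers for weight-only quantized variants.

    A minimal usage sketch (assumes ``model`` is an already-built
    tensorrt_llm model; W8A16 is just one possible scheme)::

        config = QuantConfig(quant_algo=QuantAlgo.W8A16)
        model = weight_only_quantize(model, config)
    """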
assert quant_config.quant_mode.is_weight_only()
try:
model_cfg = model.config
except Exception:
model_cfg = model_config
quant_map = {
ColumnLinear: WeightOnlyQuantColumnLinear,
RowLinear: WeightOnlyQuantRowLinear,
Embedding: WeightOnlyQuantEmbedding,
}

    def preprocess_init_params(init_params, name, module):
init_params["quant_mode"] = quant_config.quant_mode
if isinstance(module, ColumnLinear):
module_name = name.rsplit('.', 1)[-1]
init_params["transb"] = module_name == "lm_head"
init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
model,
quant_config,
quant_map,
preprocess_init_params,
)
return model


def weight_only_groupwise_quantize(model,
quant_config: QuantConfig,
model_config=None):
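    """Swap Linear layers for groupwise weight-only variants (AWQ/GPTQ style).

    A minimal usage sketch (field values are illustrative)::

        config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ, group_size=128)
        model = weight_only_groupwise_quantize(model, config)
    """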
assert quant_config.quant_mode.is_weight_only()
try:
model_cfg = model.config
except Exception:
model_cfg = model_config
quant_map = {
ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
RowLinear: WeightOnlyGroupwiseQuantRowLinear,
}

    def preprocess_init_params(init_params, name, module):
init_params["group_size"] = quant_config.group_size
init_params["pre_quant_scale"] = quant_config.pre_quant_scale
init_params["zero"] = quant_config.has_zero_point
init_params[
"use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ
init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
model,
quant_config,
quant_map,
preprocess_init_params,
)
return model


def smooth_quantize_ootb(
model,
quant_config: QuantConfig,
):
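    """Out-of-the-box SmoothQuant: only the Linear layers are replaced, so no
    fused plugin support is required at runtime.
    """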
quant_map = {
ColumnLinear: Int8SmoothQuantLinear,
RowLinear: Int8SmoothQuantRowLinear,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model


def smooth_quantize_plugin(model, quant_mode):
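    """Plugin-based SmoothQuant: swaps whole norm/MLP/attention blocks so the
    fused SmoothQuant plugins can run them. The final norms (``ln_f``,
    ``ln_embed``) stay unquantized.
    """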
quant_map = {
RmsNorm: SmoothQuantRmsNorm,
LayerNorm: SmoothQuantLayerNorm,
GatedMLP: SmoothQuantGatedMLP,
MLP: SmoothQuantMLP,
Attention: SmoothQuantAttention,
}
for name, layer, parent in model.named_modules_with_parent():
layer_name = name.rsplit('.', 1)[-1]
if layer_name in ['ln_f', 'ln_embed']:
continue
quant_cls = None
for cls in quant_map:
if isinstance(layer, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_mode
        # The quantized Attention re-shards heads by tp_size, so pass the
        # full (gathered) head count.
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, layer_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_mode)
return model


def smooth_quantize(model, quant_config: QuantConfig):
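    """Dispatch to the plugin or OOTB SmoothQuant path based on the algorithm.

    A minimal usage sketch (per-channel W8A8 is illustrative)::

        config = QuantConfig(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL)
        model = smooth_quantize(model, config)
    """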
assert quant_config.quant_mode.has_act_and_weight_quant()
if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
return smooth_quantize_plugin(model, quant_config.quant_mode)
else:
return smooth_quantize_ootb(model, quant_config)


def fp8_quantize(model, quant_config: QuantConfig):
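    """Swap Linear layers for FP8 quantize-dequantize variants.

    A minimal usage sketch::

        config = QuantConfig(quant_algo=QuantAlgo.FP8)
        model = fp8_quantize(model, config)
    """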
assert quant_config.quant_mode.has_fp8_qdq()
quant_map = {
ColumnLinear: FP8Linear,
RowLinear: FP8RowLinear,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model


def fp8_rowwise_quantize(model, quant_config: QuantConfig, model_config=None):
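    """Row-wise FP8 quantization of the norm/MLP blocks.

    Following Meta's FP8 recipe, the first and last decoder layers (and the
    input layernorms) are left unquantized.
    """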
assert quant_config.quant_mode.has_fp8_rowwise()
try:
model_cfg = model.config
except Exception:
model_cfg = model_config
quant_map = {
RmsNorm: Fp8RowwiseRmsNorm,
GatedMLP: Fp8RowwiseGatedMLP,
MLP: Fp8RowwiseMLP,
}

    def extract_layer_idx(name):
ss = name.split('.')
for s in ss:
if s.isdigit():
return int(s)
return None

    for name, layer, parent in model.named_modules_with_parent():
layer_name = name.rsplit('.', 1)[-1]
layer_idx = extract_layer_idx(name)
if layer_name in ['ln_f', 'ln_embed'] or "input_layernorm" in name:
continue
        # Meta's FP8 recipe: skip the first and last decoder layers.
mapping = model_cfg.mapping
layers_range = mapping.pp_layers(model_cfg.num_hidden_layers)
is_first_layer = mapping.is_first_pp_rank() and layer_idx == 0
is_last_layer = mapping.is_last_pp_rank(
) and layer_idx == len(layers_range) - 1
if is_first_layer or is_last_layer:
continue
quant_cls = None
for cls in quant_map:
if isinstance(layer, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_config.quant_mode
quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
if parent is not None:
setattr(parent, layer_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_config.quant_mode)
return model


# For now, assume KV-cache quantization is enabled for all layers.
def kv_cache_quantize(model):
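    """Attach a per-layer KV-cache scaling factor to every attention module.

    The scale is a single float32 ``Parameter``; its value is expected to be
    supplied when the quantized checkpoint weights are loaded.
    """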
for name, module in model.named_modules():
if isinstance(module, (Attention, SmoothQuantAttention)):
module.kv_cache_scaling_factor = Parameter(shape=(1, ),
dtype='float32')
return model


def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
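    """Top-level entry point: apply the scheme(s) described by ``quant_config``
    to ``model``, per layer for mixed precision or in one pass otherwise.

    A minimal usage sketch (FP8 weights/activations plus an FP8 KV cache;
    field values are illustrative)::

        config = QuantConfig(quant_algo=QuantAlgo.FP8,
                             kv_cache_quant_algo=QuantAlgo.FP8)
        model = quantize(model, config)
    """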
quant_mode = quant_config.layer_quant_mode
for name, module, parent in model.named_modules_with_parent():
if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
if name in quant_mode.keys():
layer_quant_mode = quant_mode[name]
else:
continue
else:
layer_quant_mode = quant_mode
if layer_quant_mode == QuantMode(0):
continue
layer_quant_cfg = quant_config.get_quant_cfg(name)
if layer_quant_mode.has_fp8_qdq():
module = fp8_quantize(module, layer_quant_cfg)
elif layer_quant_mode.has_fp8_rowwise():
module = fp8_rowwise_quantize(module, layer_quant_cfg, model.config)
elif layer_quant_mode.has_act_and_weight_quant():
module = smooth_quantize(module, layer_quant_cfg)
elif layer_quant_mode.is_weight_only():
if layer_quant_mode.has_per_group_scaling():
module = weight_only_groupwise_quantize(module, layer_quant_cfg,
model.config)
else:
module = weight_only_quantize(module, layer_quant_cfg,
model.config)
        if parent is not None:  # per-layer replacement (mixed precision)
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:  # the whole model was quantized in one pass
            model = module
            break
if quant_config.quant_mode.has_kv_cache_quant():
model = kv_cache_quantize(model)
setattr(model, 'quant_mode', quant_config.quant_mode)
return model