# TensorRT-LLM/tensorrt_llm/quantization/quantize.py
import numpy as np

from ..layers import MLP, ColumnLinear, GatedMLP, LayerNorm, RmsNorm, RowLinear
from ..parameter import Parameter
from .layers import (FP8Linear, FP8RowLinear, Int8SmoothQuantLinear,
                     Int8SmoothQuantRowLinear, SmoothQuantAttention,
                     SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
                     SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
                     WeightOnlyGroupwiseQuantRowLinear,
                     WeightOnlyQuantColumnLinear, WeightOnlyQuantRowLinear)


def weight_only_quantize(model,
                         quant_mode,
                         exclude_modules=None,
                         current_key_name=None):
    assert quant_mode.is_weight_only()

    exclude_modules = ['lm_head'
                       ] if exclude_modules is None else exclude_modules

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            weight_only_quantize(module, quant_mode, exclude_modules,
                                 current_key_name)

        if isinstance(module, ColumnLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyQuantColumnLinear(
                    in_features=module.in_features,
                    out_features=module.out_features * module.tp_size,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size,
                    gather_output=module.gather_output,
                    quant_mode=quant_mode)
        elif isinstance(module, RowLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyQuantRowLinear(
                    in_features=module.in_features * module.tp_size,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size,
                    quant_mode=quant_mode)
        current_key_name.pop(-1)

    setattr(model, 'quant_mode', quant_mode)
    return model
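
# Usage sketch (not part of this module): how weight_only_quantize is typically
# driven. `QuantMode.use_weight_only` is assumed to be available from
# tensorrt_llm.quantization, and `model` stands for an already-built
# tensorrt_llm network; both are assumptions, not code from this file.
#
#   from tensorrt_llm.quantization import QuantMode
#
#   quant_mode = QuantMode.use_weight_only(use_int4_weights=False)  # INT8 weight-only
#   model = weight_only_quantize(model, quant_mode)
#   # Every ColumnLinear/RowLinear outside 'lm_head' is now replaced by a
#   # WeightOnlyQuant{Column,Row}Linear, and model.quant_mode is set.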


def weight_only_groupwise_quantize(model,
                                   quant_mode,
                                   quant_algo='W4A16_AWQ',
                                   group_size=128,
                                   pre_quant_scale=False,
                                   zero=False,
                                   exclude_modules=None,
                                   current_key_name=None):
    assert quant_mode.is_weight_only()

    exclude_modules = ['lm_head'
                       ] if exclude_modules is None else exclude_modules

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            weight_only_groupwise_quantize(module, quant_mode, quant_algo,
                                           group_size, pre_quant_scale, zero,
                                           exclude_modules, current_key_name)

        if isinstance(module, ColumnLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyGroupwiseQuantColumnLinear(
                    in_features=module.in_features,
                    out_features=module.out_features * module.tp_size,
                    group_size=group_size,
                    pre_quant_scale=pre_quant_scale,
                    zero=zero,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size,
                    gather_output=module.gather_output)
        elif isinstance(module, RowLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyGroupwiseQuantRowLinear(
                    in_features=module.in_features * module.tp_size,
                    out_features=module.out_features,
                    group_size=group_size,
                    pre_quant_scale=pre_quant_scale,
                    zero=zero,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size)
        current_key_name.pop(-1)

    return model
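
# Usage sketch (illustrative, not from this file): a group-wise weight-only
# mode can be described with QuantMode.from_description. The exact keyword
# names below are an assumption about that helper, and `model` is a
# placeholder for an already-built network.
#
#   quant_mode = QuantMode.from_description(quantize_weights=True,
#                                           per_group=True,
#                                           use_int4_weights=True)
#   model = weight_only_groupwise_quantize(model,
#                                          quant_mode,
#                                          quant_algo='W4A16_AWQ',
#                                          group_size=128,
#                                          zero=True)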


def smooth_quantize_ootb(model,
                         quant_mode,
                         current_key_name=None,
                         exclude_modules=None):
    exclude_modules = ['lm_head'
                       ] if exclude_modules is None else exclude_modules

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            # Recurse with keyword arguments: the signature puts
            # current_key_name before exclude_modules, so forwarding the two
            # lists positionally would swap them.
            smooth_quantize_ootb(module,
                                 quant_mode,
                                 current_key_name=current_key_name,
                                 exclude_modules=exclude_modules)

        if isinstance(module, ColumnLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = Int8SmoothQuantLinear(
                    module.in_features, module.out_features * module.tp_size,
                    module.bias, module.dtype, module.tp_group, module.tp_size,
                    module.gather_output)
        elif isinstance(module, RowLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = Int8SmoothQuantRowLinear(
                    module.in_features * module.tp_size, module.out_features,
                    module.bias, module.dtype, module.tp_group, module.tp_size)
        current_key_name.pop(-1)

    return model


def smooth_quantize_plugin(model, quant_mode):
    for layer in model.transformer.layers:
        config = layer.config
        assert hasattr(layer,
                       "input_layernorm"), "The layer has no input_layernorm"
        quant_norm_cls = None
        if isinstance(layer.input_layernorm, RmsNorm):
            quant_norm_cls = SmoothQuantRmsNorm
        elif isinstance(layer.input_layernorm, LayerNorm):
            quant_norm_cls = SmoothQuantLayerNorm
        assert quant_norm_cls is not None
        layer.input_layernorm = quant_norm_cls(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
            quant_mode=quant_mode)

        assert hasattr(layer, "attention"), "The layer has no attention"
        qkv_bias = layer.attention.qkv.bias is not None
        dense_bias = layer.attention.dense.bias is not None
        layer.attention = SmoothQuantAttention(
            config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            num_layers=config.num_hidden_layers,
            dtype=config.dtype,
            attention_mask_type=layer.attention.attention_mask_type,
            position_embedding_type=layer.attention.position_embedding_type,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            tp_rank=config.mapping.tp_rank,
            quant_mode=quant_mode,
            bias=(qkv_bias and dense_bias),
            qkv_bias_only=(qkv_bias and not dense_bias))

        assert hasattr(layer, "mlp"), "The layer has no mlp"
        mlp_norm_cls = None
        if isinstance(layer.mlp, GatedMLP):
            mlp_norm_cls = SmoothQuantGatedMLP
        elif isinstance(layer.mlp, MLP):
            mlp_norm_cls = SmoothQuantMLP
        layer.mlp = mlp_norm_cls(hidden_size=config.hidden_size,
                                 ffn_hidden_size=config.intermediate_size,
                                 hidden_act=config.hidden_act,
                                 dtype=config.dtype,
                                 tp_group=config.mapping.tp_group,
                                 tp_size=config.mapping.tp_size,
                                 quant_mode=quant_mode,
                                 bias=layer.mlp.bias)

        assert hasattr(layer,
                       "post_layernorm"), "The layer has no post_layernorm"
        quant_norm_cls = None
        if isinstance(layer.post_layernorm, RmsNorm):
            quant_norm_cls = SmoothQuantRmsNorm
        elif isinstance(layer.post_layernorm, LayerNorm):
            quant_norm_cls = SmoothQuantLayerNorm
        assert quant_norm_cls is not None
        layer.post_layernorm = quant_norm_cls(
            normalized_shape=config.hidden_size,
            dtype=config.dtype,
            quant_mode=quant_mode)

    return model


def smooth_quantize(model, quant_mode, use_plugin=False):
    assert quant_mode.has_act_and_weight_quant()
    if use_plugin:
        return smooth_quantize_plugin(model, quant_mode)
    else:
        return smooth_quantize_ootb(model, quant_mode)
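
# Usage sketch (illustrative): the plugin path swaps whole attention/MLP/norm
# blocks for SmoothQuant variants, while the OOTB path only swaps the linear
# layers. QuantMode.use_smooth_quant(per_token, per_channel) is assumed to
# exist as elsewhere in TensorRT-LLM; `model` is a placeholder.
#
#   quant_mode = QuantMode.use_smooth_quant(per_token=True, per_channel=True)
#   model = smooth_quantize(model, quant_mode, use_plugin=True)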


def _get_dummy_quant_scales(num_layers):
    return {
        'lm_head_act': 0.99,
        'lm_head_weights': 0.99,
        'fc_act': [0.99 for _ in range(num_layers)],
        'fc_weights': [0.99 for _ in range(num_layers)],
        'gate_act': [0.99 for _ in range(num_layers)],
        'gate_weights': [0.99 for _ in range(num_layers)],
        'proj_act': [0.99 for _ in range(num_layers)],
        'proj_weights': [0.99 for _ in range(num_layers)],
        'qkv_act': [0.99 for _ in range(num_layers)],
        'qkv_weights': [0.99 for _ in range(num_layers)],
        'qkv_output': [5.0 for _ in range(num_layers)],
        'dense_act': [0.99 for _ in range(num_layers)],
        'dense_weights': [0.99 for _ in range(num_layers)],
    }


def _quantize_layer(layer, layer_idx, quant_mode, quant_scales):
    assert hasattr(layer, "mlp"), "The layer has no mlp"

    fake_fp8_sf_dt = np.float32
    assert isinstance(layer.mlp.fc, (FP8Linear, FP8RowLinear))
    assert isinstance(layer.mlp.proj, (FP8Linear, FP8RowLinear))
    layer.mlp.fc.activation_scaling_factor.value = np.array(
        [quant_scales['fc_act'][layer_idx]], dtype=fake_fp8_sf_dt)
    layer.mlp.fc.weights_scaling_factor.value = np.array(
        [quant_scales['fc_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
    layer.mlp.proj.activation_scaling_factor.value = np.array(
        [quant_scales['proj_act'][layer_idx]], dtype=fake_fp8_sf_dt)
    layer.mlp.proj.weights_scaling_factor.value = np.array(
        [quant_scales['proj_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
    if hasattr(layer.mlp, 'gate'):
        assert isinstance(layer.mlp.gate, (FP8Linear, FP8RowLinear))
        layer.mlp.gate.activation_scaling_factor.value = np.array(
            [quant_scales['gate_act'][layer_idx]], dtype=fake_fp8_sf_dt)
        layer.mlp.gate.weights_scaling_factor.value = np.array(
            [quant_scales['gate_weights'][layer_idx]], dtype=fake_fp8_sf_dt)

    assert hasattr(layer, "attention"), "The layer has no attention"
    assert isinstance(layer.attention.qkv, (FP8Linear, FP8RowLinear))
    assert isinstance(layer.attention.dense, (FP8Linear, FP8RowLinear))
    layer.attention.qkv.activation_scaling_factor.value = np.array(
        [quant_scales['qkv_act'][layer_idx]], dtype=fake_fp8_sf_dt)
    layer.attention.qkv.weights_scaling_factor.value = np.array(
        [quant_scales['qkv_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
    if quant_mode.has_fp8_kv_cache():
        layer.attention.kv_cache_scaling_factor.value = np.array(
            [1.0], dtype=fake_fp8_sf_dt)
    layer.attention.dense.activation_scaling_factor.value = np.array(
        [quant_scales['dense_act'][layer_idx]], dtype=fake_fp8_sf_dt)
    layer.attention.dense.weights_scaling_factor.value = np.array(
        [quant_scales['dense_weights'][layer_idx]], dtype=fake_fp8_sf_dt)

    return layer


def default_fp8_quantize(model, quant_mode, quant_scales: dict = None):
    """
    Quantize all linear layers (i.e., MLP, Attention QKV/Dense) and KV cache
    I/O with dummy scales.

    This is used by the benchmark script and is therefore intentionally
    decoupled from the AMMO toolkit.
    """
    if quant_scales is None:
        quant_scales = _get_dummy_quant_scales(model.config.num_hidden_layers)

    assert model.config.quant_mode == quant_mode, "Quant setting not consistent with model init setting"

    use_fp8_qdq = quant_mode.has_fp8_qdq()
    assert use_fp8_qdq

    for layer_idx, layer in enumerate(model.transformer.layers):
        layer = _quantize_layer(layer, layer_idx, quant_mode, quant_scales)

    # TODO: add lm_head
    return model
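
# Usage sketch (illustrative): benchmark-style FP8 quantization with dummy
# scales. The from_description keyword names are an assumption about QuantMode,
# and the model must have been built with this same quant_mode in its config.
#
#   quant_mode = QuantMode.from_description(use_fp8_qdq=True,
#                                           use_fp8_kv_cache=True)
#   model = default_fp8_quantize(model, quant_mode)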


def quantize_kv_cache(model, quant_mode):
    for layer in model.transformer.layers:
        if quant_mode.has_kv_cache_quant():
            layer.attention.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                                dtype='float32')
        else:
            layer.attention.register_parameter('kv_cache_scaling_factor', None)
    return model
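
# Note (added comment): quantize_kv_cache only registers (or clears) the
# per-layer `kv_cache_scaling_factor` Parameter; actual scale values are
# assigned later, e.g. _quantize_layer above writes 1.0 for the dummy FP8 path,
# and calibrated scales are expected to come from weight loading. A minimal
# sketch, assuming QuantMode.from_description accepts use_int8_kv_cache:
#
#   quant_mode = QuantMode.from_description(use_int8_kv_cache=True)
#   model = quantize_kv_cache(model, quant_mode)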


def quantize(model, quant_mode, **kwargs):
    quantize_kv_cache(model, quant_mode)

    if quant_mode.has_act_and_weight_quant():
        if 'sq_use_plugin' in kwargs and kwargs['sq_use_plugin']:
            smooth_quantize(model, quant_mode, use_plugin=True)
        else:
            smooth_quantize(model, quant_mode)
    elif quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache():
        # FIXME(guomingz): make llama use AMMO 0.7.0 new checkpoint directly
        if model.config.architecture in [
                'LlamaForCausalLM', 'InternLMForCausalLM', 'MedusaForCausalLM'
        ]:
            default_fp8_quantize(model, quant_mode)
    elif quant_mode.is_weight_only():
        if quant_mode.has_per_group_scaling():
            kwargs = {
                k: kwargs[k]
                for k in [
                    'quant_algo', 'group_size', 'zero', 'pre_quant_scale',
                    'exclude_modules'
                ]
            }
            weight_only_groupwise_quantize(model, quant_mode, **kwargs)
        else:
            kwargs = {k: kwargs[k] for k in ['exclude_modules']}
            weight_only_quantize(model, quant_mode, **kwargs)
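
# Usage sketch (illustrative): `quantize` is the single entry point; each path
# reads only the kwargs it needs, and the weight-only paths require their keys
# to be present. `model` and the kwarg values are placeholders/assumptions.
#
#   # Non-group-wise weight-only path: only 'exclude_modules' is consumed,
#   # but it must be supplied.
#   quantize(model, quant_mode, exclude_modules=['lm_head'])
#
#   # SmoothQuant path: only 'sq_use_plugin' is inspected.
#   quantize(model, quant_mode, sq_use_plugin=True)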