TensorRT-LLMs/tensorrt_llm/quantization/quantize.py
nv-guomingz adf60a8723
fix:update the default excluded_modules value for fp8rowwise recipe. (#3477)
Signed-off-by: nv-guomingz <37257613+nv-guomingz@users.noreply.github.com>
2025-04-12 16:00:21 +08:00

596 lines
22 KiB
Python

import fnmatch
from typing import Union
import torch
from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
from .layers import (FP4Linear, FP4RowLinear, FP8Linear, FP8RowLinear,
Fp8RowwiseAttention, Fp8RowwiseGatedMLP, Fp8RowwiseMLP,
Fp8RowwiseRmsNorm, Int8SmoothQuantLinear,
Int8SmoothQuantRowLinear, QServeAttention, QServeGatedMLP,
QServeMLP, QServeRmsNorm, SmoothQuantAttention,
SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
WeightOnlyGroupwiseQuantRowLinear,
WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode
def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
    """Replace the quantizable sub-modules of ``model`` with quantized ones.

    Args:
        model: Root module; it must provide ``named_modules_with_parent()``
            yielding ``(name, module, parent)`` triples.
        quant_config: Supplies ``exclude_modules`` (glob patterns matched
            case-sensitively against the full module name) and ``quant_mode``.
        quant_map: Mapping from a source layer class to its replacement
            class. The first key a module is an instance of is used.
        preprocess_init_params: Optional hook invoked as
            ``hook(init_params, name, module)`` to adjust the constructor
            arguments in place right before the replacement is created.

    Returns:
        The model after replacement (the replacement layer itself when the
        root module has no parent), with its ``quant_mode`` attribute set.
    """
    fallback_patterns = [
        '*lm_head',
        '*router',
        '*vocab_embedding',
        '*position_embedding',
        '*block_embedding',
        '*shared_expert_gate',
    ]
    patterns = quant_config.exclude_modules or fallback_patterns

    for name, module, parent in model.named_modules_with_parent():
        skipped = any(fnmatch.fnmatchcase(name, pat) for pat in patterns)
        is_moe = isinstance(module, MixtureOfExperts)

        if is_moe:
            # MoE modules are quantized in their constructor, so they are
            # always re-created (with quant_mode 0 when excluded) instead of
            # being looked up in quant_map.
            target_cls = type(module)
        elif skipped:
            target_cls = None
        else:
            target_cls = next(
                (quant_map[src] for src in quant_map
                 if isinstance(module, src)),
                None,
            )

        if not target_cls:
            continue

        ctor_args = get_init_params(module, target_cls)

        if is_moe:
            ctor_args["quant_mode"] = (QuantMode(0) if skipped else
                                       quant_config.quant_mode)
        elif "bias" in ctor_args:
            # Non-MoE constructors receive a boolean flag here, not the
            # value captured from the existing layer.
            ctor_args["bias"] = ctor_args["bias"] is not None

        # Feature counts are multiplied by tp_size before being handed back
        # to the constructor.
        if isinstance(module, ColumnLinear):
            ctor_args["out_features"] = module.out_features * module.tp_size
        elif isinstance(module, RowLinear):
            ctor_args["in_features"] = module.in_features * module.tp_size

        if preprocess_init_params is not None:
            preprocess_init_params(ctor_args, name, module)

        replacement = target_cls(**ctor_args)
        if parent is None:
            model = replacement
        else:
            setattr(parent, name.rsplit('.', 1)[-1], replacement)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
    """Swap linear and embedding layers for their weight-only quantized forms.

    Args:
        model: Module tree to convert.
        quant_config: Must describe a weight-only quant mode (asserted).
        model_config: Used in place of ``model.config`` when ``model`` has no
            ``config`` attribute.

    Returns:
        The converted model, as returned by ``quantize_layers``.
    """
    assert quant_config.quant_mode.is_weight_only()

    # Prefer the config carried by the module; fall back to the argument.
    model_cfg = getattr(model, "config", model_config)

    def _adjust_init_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            # transb is enabled only for the layer whose leaf name is lm_head.
            init_params["transb"] = name.rsplit('.', 1)[-1] == "lm_head"
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    layer_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }
    return quantize_layers(
        model,
        quant_config,
        layer_map,
        _adjust_init_params,
    )
def weight_only_groupwise_quantize(model,
                                   quant_config: QuantConfig,
                                   model_config=None):
    """Swap linear layers for group-wise weight-only quantized versions.

    Args:
        model: Module tree to convert.
        quant_config: Must describe a weight-only quant mode (asserted); also
            provides ``group_size``, ``pre_quant_scale``, ``has_zero_point``
            and ``quant_algo`` for the replacement constructors.
        model_config: Used in place of ``model.config`` when ``model`` has no
            ``config`` attribute.

    Returns:
        The converted model, as returned by ``quantize_layers``.
    """
    assert quant_config.quant_mode.is_weight_only()

    # Prefer the config carried by the module; fall back to the argument.
    model_cfg = getattr(model, "config", model_config)

    def _adjust_init_params(init_params, name, module):
        init_params.update({
            "group_size": quant_config.group_size,
            "pre_quant_scale": quant_config.pre_quant_scale,
            "zero": quant_config.has_zero_point,
            "use_w4a8_awq": quant_config.quant_algo == QuantAlgo.W4A8_AWQ,
            "use_int8_weight": quant_config.quant_algo == QuantAlgo.W8A16_GPTQ,
        })
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    layer_map = {
        ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
        RowLinear: WeightOnlyGroupwiseQuantRowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }
    return quantize_layers(
        model,
        quant_config,
        layer_map,
        _adjust_init_params,
    )
def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
    """Apply the non-plugin SmoothQuant path to ``model``.

    Column/row linear layers are replaced by their INT8 SmoothQuant
    counterparts through ``quantize_layers``; the converted model is returned.
    """
    layer_map = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }
    return quantize_layers(model, quant_config, layer_map)
def smooth_quantize_plugin(model, quant_mode):
    """Apply the plugin SmoothQuant path to ``model``.

    Norm, MLP and attention layers are replaced by their SmoothQuant
    versions. Layers whose leaf name is ``ln_f`` or ``ln_embed`` are left
    untouched.

    Args:
        model: Module tree providing ``named_modules_with_parent()``.
        quant_mode: Quant mode passed to every replacement layer and stored
            on the returned model.

    Returns:
        The converted model (the replacement itself if the root was replaced).
    """
    # Order matters: the first class a layer is an instance of is used.
    layer_map = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    untouched_names = ('ln_f', 'ln_embed')

    for name, layer, parent in model.named_modules_with_parent():
        leaf_name = name.rsplit('.', 1)[-1]
        if leaf_name in untouched_names:
            continue

        target_cls = next(
            (layer_map[src] for src in layer_map if isinstance(layer, src)),
            None,
        )
        if target_cls is None:
            continue

        ctor_args = get_init_params(layer, target_cls)
        ctor_args["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            # The head count is multiplied by tp_size before re-construction.
            ctor_args["num_attention_heads"] = (layer.num_attention_heads *
                                                layer.tp_size)

        replacement = target_cls(**ctor_args)
        if parent is None:
            model = replacement
        else:
            setattr(parent, leaf_name, replacement)

    setattr(model, 'quant_mode', quant_mode)
    return model
def smooth_quantize(model, quant_config: QuantConfig):
    """Dispatch SmoothQuant conversion to the plugin or non-plugin path.

    The plugin path is taken when ``quant_config.quant_algo`` is listed in
    ``W8A8_SQ_PLUGIN_LIST``; otherwise the non-plugin path is used. The
    config must have activation-and-weight quantization enabled (asserted).
    """
    assert quant_config.quant_mode.has_act_and_weight_quant()
    if quant_config.quant_algo not in W8A8_SQ_PLUGIN_LIST:
        return smooth_quantize_ootb(model, quant_config)
    return smooth_quantize_plugin(model, quant_config.quant_mode)
def fp8_quantize(model, quant_config: QuantConfig):
    """Replace linear and MoE layers of ``model`` with FP8 versions.

    ``quant_config`` must have FP8 QDQ enabled (asserted). The conversion is
    delegated to ``quantize_layers`` and its result is returned.
    """
    assert quant_config.quant_mode.has_fp8_qdq()
    layer_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }
    return quantize_layers(model, quant_config, layer_map)
def fp8_rowwise_quantize(model, quant_config: QuantConfig):
    """Replace norm, MLP and attention layers with FP8 row-wise versions.

    Modules matching ``quant_config.exclude_modules`` or the patterns
    ``*ln_f``, ``*ln_embed`` and ``*lm_head`` are never replaced. When
    ``quant_config.use_meta_recipe`` is set, input layernorms and attention
    modules are excluded as well, and so is every module of the first layer
    on the first pipeline rank and of the last local layer on the last
    pipeline rank.

    Args:
        model: Module tree providing ``named_modules_with_parent()``; with
            the meta recipe it must also expose ``config.mapping`` and
            ``config.num_hidden_layers``.
        quant_config: Must have FP8 row-wise mode enabled (asserted); also
            provides ``clamp_val`` for the replacement constructors.

    Returns:
        The converted model with its ``quant_mode`` attribute set.
    """
    assert quant_config.quant_mode.has_fp8_rowwise()

    # Order matters: the first class a layer is an instance of is used.
    layer_map = {
        RmsNorm: Fp8RowwiseRmsNorm,
        GatedMLP: Fp8RowwiseGatedMLP,
        MLP: Fp8RowwiseMLP,
        Attention: Fp8RowwiseAttention,
    }

    user_patterns = quant_config.exclude_modules or []
    patterns = list(set(user_patterns + ['*ln_f', '*ln_embed', '*lm_head']))

    def extract_layer_idx(name):
        # The first purely numeric component of the dotted name, if any.
        return next(
            (int(part) for part in name.split('.') if part.isdigit()),
            None,
        )

    # Meta's LLaMA 3.1 recipe:
    # (1) Skip quantization for the first and last Transformer layers
    # (2) Skip quantization for the Attention layers
    meta_recipe = quant_config.use_meta_recipe
    if meta_recipe:
        patterns.extend(['*input_layernorm', '*attention'])

    for name, layer, parent in model.named_modules_with_parent():
        if meta_recipe:
            layer_idx = extract_layer_idx(name)
            mapping = model.config.mapping
            local_layers = mapping.pp_layers(model.config.num_hidden_layers)
            on_first_layer = mapping.is_first_pp_rank() and layer_idx == 0
            if on_first_layer or (mapping.is_last_pp_rank()
                                  and layer_idx == len(local_layers) - 1):
                continue

        target_cls = next(
            (layer_map[src] for src in layer_map if isinstance(layer, src)),
            None,
        )
        if target_cls is None:
            continue

        if any(fnmatch.fnmatchcase(name, pat) for pat in patterns):
            continue

        ctor_args = get_init_params(layer, target_cls)
        ctor_args["quant_mode"] = quant_config.quant_mode
        if isinstance(layer, Attention):
            # The head count is multiplied by tp_size before re-construction.
            ctor_args["num_attention_heads"] = (layer.num_attention_heads *
                                                layer.tp_size)

        replacement = target_cls(**ctor_args, clamp_val=quant_config.clamp_val)
        if parent is None:
            model = replacement
        else:
            setattr(parent, name.rsplit('.', 1)[-1], replacement)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
# TODO: These functions should be moved to ModelOpt.
def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
                                     s1_scales: torch.FloatTensor,
                                     s2_scales: torch.FloatTensor,
                                     s2_szeros: torch.FloatTensor,
                                     group_size: int) -> torch.CharTensor:
    """Quantize a 2-D weight in two stages: per-row int8, then per-group int4.

    Args:
        linear_weight: Weight of shape ``(out_features, in_features)``.
        s1_scales: One stage-1 scale per output row.
        s2_scales: One stage-2 scale per ``(row, group)``.
        s2_szeros: One stage-2 zero offset per ``(row, group)``; it is added
            before dividing by the stage-2 scale.
        group_size: Number of consecutive input features per group.

    Returns:
        An int8 tensor of shape ``(out_features, in_features)`` whose values
        lie in ``[0, 15]``.

    Raises:
        AssertionError: If a stage produces values outside its target range.
    """
    rows, cols = linear_weight.shape[0], linear_weight.shape[1]
    target_device = linear_weight.device
    groups_per_row = cols // group_size

    # Stage 1: scale each row and round to the int8 range.
    # NOTE: a tighter +/-119 bound used to be checked here; the full int8
    # range is what is enforced now.
    row_scale = s1_scales.reshape(rows, 1).to(target_device)
    stage1 = linear_weight.div(row_scale).round()
    assert (stage1.min() >= -128 and stage1.max()
            <= 127), "Stage 1: Quantized weight out of range"

    # Stage 2: shift and scale each group into the unsigned int4 range.
    grouped = stage1.reshape(rows, groups_per_row, group_size)
    group_zero = s2_szeros.reshape(rows, groups_per_row, 1)
    group_zero = group_zero.to(torch.float16).to(target_device)
    group_scale = s2_scales.reshape(rows, groups_per_row, 1)
    group_scale = group_scale.to(torch.float16).to(target_device)
    stage2 = grouped.add(group_zero).div(group_scale).round()
    assert (stage2.min() >= 0 and stage2.max()
            <= 15), "Stage 2: Quantized weight out of range"

    return stage2.reshape(rows, cols).to(torch.int8)
def qserve_quantize_weight_per_channel(
        linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
        s1_szeros: torch.FloatTensor) -> torch.CharTensor:
    """Quantize a 2-D weight to unsigned int4 values, one scale/zero per row.

    Each row is shifted by its zero offset, divided by its scale and rounded.

    Args:
        linear_weight: Weight of shape ``(out_features, in_features)``.
        s1_scales: One scale per output row.
        s1_szeros: One zero offset per output row, added before scaling.

    Returns:
        An int8 tensor of shape ``(out_features, in_features)`` whose values
        lie in ``[0, 15]``.

    Raises:
        AssertionError: If a quantized value falls outside ``[0, 15]``.
    """
    rows, cols = linear_weight.shape[0], linear_weight.shape[1]
    target_device = linear_weight.device

    row_scale = s1_scales.reshape(rows, 1).to(target_device)
    row_zero = s1_szeros.reshape(rows, 1).to(target_device)

    quantized = linear_weight.add(row_zero).div(row_scale).round()
    assert (quantized.min() >= 0
            and quantized.max() <= 15), "Quantized weight out of range"

    return quantized.reshape(rows, cols).to(torch.int8)
# Pack the quantized weights, scales and zeros and apply the reordering required by QServe kernels.
# Return: processed [qweight, s1_scales, s2_scales, s2_zeros]
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
                                  s1_scales: torch.FloatTensor,
                                  s2_scales: torch.FloatTensor,
                                  s2_szeros: torch.FloatTensor, group_size):
    """Pack per-group int4 weights, scales and zeros into the QServe layout.

    Args:
        qweight: int8 tensor of shape ``(out_features, in_features)`` holding
            one 4-bit value per element. Both dimensions are tiled in blocks
            of 32.
        s1_scales: One stage-1 scale per output row.
        s2_scales: One stage-2 scale per ``(row, group)``; cast to int8.
        s2_szeros: One stage-2 zero per ``(row, group)``; cast to int8 and
            negated in the output.
        group_size: Number of input features per group.

    Returns:
        A list ``[packed_weight, s1_scales, s2_scales, s2_zeros]`` where
        ``packed_weight`` is int8 of shape ``(out_features, in_features // 2)``
        (two 4-bit values per byte), ``s1_scales`` is float16 of shape
        ``(out_features,)`` and the stage-2 tensors are int8 of shape
        ``(in_features // group_size, out_features)``.
    """
    out_features, in_features = qweight.shape[0], qweight.shape[1]
    num_groups = in_features // group_size
    row_tiles, col_tiles = out_features // 32, in_features // 32

    def _retile_per_group(values):
        # (out, groups) -> (groups, out). Within every block of 32 outputs
        # the order becomes 0, 8, 16, 24, 1, 9, 17, 25, ... following the
        # requirement of the tensor core gemm.
        by_group = values.reshape(out_features,
                                  num_groups).transpose(0, 1).contiguous()
        by_group = by_group.reshape(num_groups, row_tiles, 32)
        swapped = by_group.reshape(num_groups, row_tiles, 4,
                                   8).transpose(-2, -1).contiguous()
        return swapped.reshape(num_groups, out_features).contiguous()

    row_scales = s1_scales.reshape(out_features).to(torch.float16)
    group_zeros = s2_szeros.reshape(out_features, num_groups).to(torch.int8)
    group_scales = s2_scales.reshape(out_features, num_groups).to(torch.int8)

    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    tiled = qweight.reshape(row_tiles, 2, 2, 8, col_tiles, 2, 4, 4)
    tiled = tiled.permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous()
    tiled = tiled.permute(0, 1, 2, 3, 5, 6, 7, 4).contiguous()
    tiled = tiled.to(torch.int8)
    # Rows r and r + 16 of a tile share one byte: [16, 0, 17, 1, ...]
    high_nibble = tiled[..., 1] << 4
    low_nibble = tiled[..., 0]
    packed = high_nibble + low_nibble
    packed = packed.reshape(row_tiles, col_tiles, 32, 16)
    packed = packed.reshape(out_features, in_features // 2)

    # ---- Pack the scales ---- #
    packed_scales = _retile_per_group(group_scales)

    # ---- Pack the zeros ---- #
    # (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
    # so the zeros are stored negated. They have been pre-scaled in
    # DeepCompressor.
    packed_zeros = (-_retile_per_group(group_zeros)).int().to(torch.int8)

    return [
        packed,
        row_scales.reshape(out_features),
        packed_scales,
        packed_zeros,
    ]
def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
                                    s1_scales: torch.FloatTensor,
                                    s1_szeros: torch.FloatTensor):
    """Pack per-channel int4 weights, scales and zeros into the QServe layout.

    Args:
        qweight: int8 tensor of shape ``(out_features, in_features)`` holding
            one 4-bit value per element. Both dimensions are tiled in blocks
            of 32.
        s1_scales: One scale per output row.
        s1_szeros: One zero per output row.

    Returns:
        A list ``[packed_weight, s1_scales, s1_szeros]`` where
        ``packed_weight`` is int8 of shape ``(out_features, in_features // 2)``
        (two 4-bit values per byte) and the other two entries are float16
        tensors of shape ``(out_features,)``.
    """
    out_features, in_features = qweight.shape[0], qweight.shape[1]
    row_tiles, col_tiles = out_features // 32, in_features // 32

    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    tiled = qweight.reshape(row_tiles, 2, 2, 8, col_tiles, 2, 4, 4)
    tiled = tiled.permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous()
    tiled = tiled.permute(0, 1, 2, 3, 5, 6, 7, 4).contiguous()
    # Rows r and r + 16 of a tile share one byte: [16, 0, 17, 1, ...]
    high_nibble = tiled[..., 1] << 4
    low_nibble = tiled[..., 0]
    packed = high_nibble + low_nibble
    packed = packed.reshape(row_tiles, col_tiles, 32, 16)
    packed = packed.reshape(out_features, in_features // 2).contiguous()

    # ---- Pack the scales and zeros ---- #
    row_scales = s1_scales.reshape(out_features).contiguous().half()
    row_zeros = s1_szeros.reshape(out_features).contiguous().half()

    return [packed, row_scales, row_zeros]
# TODO: Duplicates smooth_quantize and quantize_layers
def qserve_quantize(model, quant_config: QuantConfig):
    """Replace norm, MLP and attention layers with their QServe versions.

    Layers whose leaf name is ``ln_f`` or ``ln_embed`` are left untouched.
    Both ``RmsNorm`` and ``LayerNorm`` map to ``QServeRmsNorm``.

    Args:
        model: Module tree providing ``named_modules_with_parent()``.
        quant_config: Must describe the QServe W4A8 mode (asserted); its
            ``quant_mode`` is passed to every replacement layer.

    Returns:
        The converted model with its ``quant_mode`` attribute set.
    """
    quant_mode = quant_config.quant_mode
    assert quant_config.quant_mode.is_qserve_w4a8()

    # Order matters: the first class a layer is an instance of is used.
    layer_map = {
        RmsNorm: QServeRmsNorm,
        LayerNorm: QServeRmsNorm,
        GatedMLP: QServeGatedMLP,
        MLP: QServeMLP,
        Attention: QServeAttention,
    }
    untouched_names = ('ln_f', 'ln_embed')

    for name, layer, parent in model.named_modules_with_parent():
        leaf_name = name.rsplit('.', 1)[-1]
        if leaf_name in untouched_names:
            continue

        target_cls = next(
            (layer_map[src] for src in layer_map if isinstance(layer, src)),
            None,
        )
        if target_cls is None:
            continue

        ctor_args = get_init_params(layer, target_cls)
        ctor_args["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            # The head count is multiplied by tp_size before re-construction.
            ctor_args["num_attention_heads"] = (layer.num_attention_heads *
                                                layer.tp_size)

        replacement = target_cls(**ctor_args)
        if parent is None:
            model = replacement
        else:
            setattr(parent, leaf_name, replacement)

    setattr(model, 'quant_mode', quant_mode)
    return model
def fp4_quantize(model, quant_config: QuantConfig):
    """Replace linear and MoE layers of ``model`` with FP4 versions.

    ``quant_config`` must have NVFP4 enabled (asserted). The conversion is
    delegated to ``quantize_layers`` and its result is returned.
    """
    assert quant_config.quant_mode.has_nvfp4()
    layer_map = {
        ColumnLinear: FP4Linear,
        RowLinear: FP4RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }
    return quantize_layers(model, quant_config, layer_map)
# For now, KV cache quantization is assumed to be enabled for all layers.
def kv_cache_quantize(model):
    """Attach KV cache scaling-factor parameters to attention modules.

    Every module that is an ``Attention``, ``SmoothQuantAttention`` or
    ``Fp8RowwiseAttention`` gets two new float32 parameters of shape ``(1,)``.
    The same ``model`` object is returned.
    """
    attention_types = (Attention, SmoothQuantAttention, Fp8RowwiseAttention)
    for _, module in model.named_modules():
        if not isinstance(module, attention_types):
            continue
        # for dequant
        module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                   dtype='float32')
        # for quant
        module.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
    return model
def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    """Quantize ``model`` according to ``quant_config``.

    With ``QuantAlgo.MIXED_PRECISION`` the quant mode is looked up per module
    name and each matching module is converted and re-attached to its parent.
    Otherwise a single quant mode applies: the first visited module without a
    parent is converted as a whole and the traversal stops.

    Args:
        model: Module tree providing ``named_modules_with_parent()`` and a
            ``config`` attribute (used by the weight-only paths).
        quant_config: Global or per-layer quantization settings.

    Returns:
        The quantized model, with KV cache scaling parameters attached when
        KV cache quantization is enabled and ``quant_mode`` set on it.
    """
    # Loop-invariant: whether quant modes are resolved per module name.
    is_mixed_precision = quant_config.quant_algo == QuantAlgo.MIXED_PRECISION
    no_quant = QuantMode(0)

    for name, module, parent in model.named_modules_with_parent():
        if is_mixed_precision:
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            layer_quant_mode = quant_config.layer_quant_mode
        if layer_quant_mode == no_quant:
            continue

        # Every branch below receives the config resolved for this module
        # name, so each helper only sees the settings that apply to it.
        layer_quant_cfg = quant_config._get_quant_cfg(name)

        if layer_quant_mode.has_fp8_qdq():
            module = fp8_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_fp8_rowwise():
            module = fp8_rowwise_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_qserve_w4a8():
            module = qserve_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_nvfp4():
            module = fp4_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_act_and_weight_quant():
            module = smooth_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                module = weight_only_groupwise_quantize(module, layer_quant_cfg,
                                                        model.config)
            else:
                module = weight_only_quantize(module, layer_quant_cfg,
                                              model.config)

        if parent is not None:  # for per layer
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:  # for all layer
            # The whole tree was handled in one call; nothing left to visit.
            model = module
            break

    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model