TensorRT-LLMs/tensorrt_llm/quantization/quantize.py

import fnmatch
from typing import Union
import torch
from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
# isort: off
from .layers import (
FP4Linear, FP4RowLinear, FP8Linear, FP8RowLinear, Fp8RowwiseAttention,
Fp8RowwiseGatedMLP, Fp8RowwiseLayerNorm, Fp8RowwiseMLP, Fp8RowwiseRmsNorm,
Int8SmoothQuantLinear, Int8SmoothQuantRowLinear, QServeAttention,
QServeGatedMLP, QServeMLP, QServeRmsNorm, SmoothQuantAttention,
SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
WeightOnlyGroupwiseQuantRowLinear, WeightOnlyQuantColumnLinear,
WeightOnlyQuantEmbedding, WeightOnlyQuantRowLinear)
# isort: on
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode
def quantize_layers(
model,
quant_config: QuantConfig,
quant_map,
preprocess_init_params=None,
):
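    """Replace quantizable sub-modules of `model` with the quantized classes
    listed in `quant_map`.

    Modules whose names match `quant_config.exclude_modules` (or the default
    exclusion list below) are skipped, except MixtureOfExperts modules, which
    are always re-created -- with QuantMode(0) when excluded. If given,
    `preprocess_init_params` may adjust a layer's constructor arguments before
    it is rebuilt. Returns the model with its `quant_mode` attribute set.
    """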
exclude_modules = quant_config.exclude_modules
if exclude_modules is None:
exclude_modules = [
'*lm_head',
'*router',
'*vocab_embedding',
'*position_embedding',
'*block_embedding',
'*shared_expert_gate',
]
for name, module, parent in model.named_modules_with_parent():
module_name = name.rsplit('.', 1)[-1]
is_excluded = False
quant_cls = None
# handle exclusion
for exclude_module in exclude_modules:
if fnmatch.fnmatchcase(name, exclude_module):
is_excluded = True
break
        # MoE modules are quantized in their constructor, so they must always
        # be re-created with the appropriate quant_mode; when excluded,
        # re-create them with QuantMode(0).
        # This needs special handling here; the MoE implementation may be
        # redesigned later to avoid it.
if isinstance(module, MixtureOfExperts):
quant_cls = type(module)
elif not is_excluded:
for cls in quant_map:
if isinstance(module, cls):
quant_cls = quant_map[cls]
break
if quant_cls:
init_params = get_init_params(module, quant_cls)
if isinstance(module, MixtureOfExperts):
if is_excluded:
quant_mode = QuantMode(0)
else:
quant_mode = quant_config.quant_mode
init_params["quant_mode"] = quant_mode
if "bias" in init_params and not isinstance(module,
MixtureOfExperts):
init_params["bias"] = init_params["bias"] is not None
if isinstance(module, ColumnLinear):
init_params[
"out_features"] = module.out_features * module.tp_size
elif isinstance(module, RowLinear):
init_params["in_features"] = module.in_features * module.tp_size
if preprocess_init_params is not None:
preprocess_init_params(init_params, name, module)
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, module_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_config.quant_mode)
return model
def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
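    """Swap ColumnLinear/RowLinear/Embedding layers for their weight-only
    quantized counterparts (INT8/INT4 weights, higher-precision activations)."""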
assert quant_config.quant_mode.is_weight_only()
try:
model_cfg = model.config
except AttributeError:
model_cfg = model_config
quant_map = {
ColumnLinear: WeightOnlyQuantColumnLinear,
RowLinear: WeightOnlyQuantRowLinear,
Embedding: WeightOnlyQuantEmbedding,
}
def preprocess_init_params(init_params, name, module):
init_params["quant_mode"] = quant_config.quant_mode
if isinstance(module, ColumnLinear):
module_name = name.rsplit('.', 1)[-1]
init_params["transb"] = module_name == "lm_head"
if "tp_rank" in init_params:
init_params["tp_rank"] = model_cfg.mapping.tp_rank
model = quantize_layers(
model,
quant_config,
quant_map,
preprocess_init_params,
)
return model
def weight_only_groupwise_quantize(model,
quant_config: QuantConfig,
model_config=None):
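    """Apply groupwise weight-only quantization (AWQ/GPTQ-style) to linear and
    MoE layers, forwarding the group size, zero-point and pre-quant-scale
    settings from `quant_config`."""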
assert quant_config.quant_mode.is_weight_only()
try:
model_cfg = model.config
except AttributeError:
model_cfg = model_config
quant_map = {
ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
RowLinear: WeightOnlyGroupwiseQuantRowLinear,
MixtureOfExperts: MixtureOfExperts,
}
def preprocess_init_params(init_params, name, module):
init_params["group_size"] = quant_config.group_size
init_params["pre_quant_scale"] = quant_config.pre_quant_scale
init_params["zero"] = quant_config.has_zero_point
init_params[
"use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ
init_params[
"use_int8_weight"] = quant_config.quant_algo == QuantAlgo.W8A16_GPTQ
if "tp_rank" in init_params:
init_params["tp_rank"] = model_cfg.mapping.tp_rank
model = quantize_layers(
model,
quant_config,
quant_map,
preprocess_init_params,
)
return model
def smooth_quantize_ootb(
model,
quant_config: QuantConfig,
):
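    """SmoothQuant without the dedicated plugins: replace linear layers with
    INT8 SmoothQuant GEMM layers."""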
quant_map = {
ColumnLinear: Int8SmoothQuantLinear,
RowLinear: Int8SmoothQuantRowLinear,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model
def smooth_quantize_plugin(model, quant_mode):
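    """SmoothQuant with the TensorRT-LLM plugins: replace norm, MLP and
    attention modules with their SmoothQuant variants, leaving the final
    'ln_f'/'ln_embed' norms untouched."""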
quant_map = {
RmsNorm: SmoothQuantRmsNorm,
LayerNorm: SmoothQuantLayerNorm,
GatedMLP: SmoothQuantGatedMLP,
MLP: SmoothQuantMLP,
Attention: SmoothQuantAttention,
}
for name, layer, parent in model.named_modules_with_parent():
layer_name = name.rsplit('.', 1)[-1]
if layer_name in ['ln_f', 'ln_embed']:
continue
quant_cls = None
for cls in quant_map:
if isinstance(layer, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_mode
if isinstance(layer, Attention):
init_params[
"num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, layer_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_mode)
return model
def smooth_quantize(model, quant_config: QuantConfig):
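    """Dispatch to the plugin or out-of-the-box SmoothQuant path based on
    `quant_config.quant_algo`."""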
assert quant_config.quant_mode.has_act_and_weight_quant()
if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
return smooth_quantize_plugin(model, quant_config.quant_mode)
else:
return smooth_quantize_ootb(model, quant_config)
def fp8_quantize(model, quant_config: QuantConfig):
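    """Per-tensor FP8 (QDQ) quantization: replace linear layers with their FP8
    counterparts; MoE modules are re-created with the FP8 quant mode."""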
assert quant_config.quant_mode.has_fp8_qdq()
quant_map = {
ColumnLinear: FP8Linear,
RowLinear: FP8RowLinear,
MixtureOfExperts: MixtureOfExperts,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model
def fp8_rowwise_quantize(model, quant_config: QuantConfig):
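    """FP8 rowwise quantization: replace norm, MLP and attention modules with
    their Fp8Rowwise variants. 'ln_f', 'ln_embed' and 'lm_head' are always
    excluded; with `use_meta_recipe`, the first and last decoder layers and
    the attention/input-layernorm modules are skipped as well."""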
assert quant_config.quant_mode.has_fp8_rowwise()
quant_cls_map = {
RmsNorm: Fp8RowwiseRmsNorm,
LayerNorm: Fp8RowwiseLayerNorm,
GatedMLP: Fp8RowwiseGatedMLP,
MLP: Fp8RowwiseMLP,
Attention: Fp8RowwiseAttention,
}
exclude_modules = quant_config.exclude_modules
if exclude_modules is None:
exclude_modules = []
# Always exclude these modules for FP8 rowwise
exclude_modules = list(
set(exclude_modules + ['*ln_f', '*ln_embed', '*lm_head']))
def extract_layer_idx(name):
ss = name.split('.')
for s in ss:
if s.isdigit():
return int(s)
return None
# Meta's LLaMA 3.1 recipe:
# (1) Skip quantization for the first and last Transformer layers
# (2) Skip quantization for the Attention layers
if quant_config.use_meta_recipe:
exclude_modules.extend(['*input_layernorm', '*attention'])
for name, layer, parent in model.named_modules_with_parent():
module_name = name.rsplit('.', 1)[-1]
if quant_config.use_meta_recipe:
local_layer_idx = extract_layer_idx(name)
mapping = model.config.mapping
layers_range = mapping.pp_layers(model.config.num_hidden_layers)
if mapping.is_first_pp_rank() and local_layer_idx == 0:
continue
if mapping.is_last_pp_rank(
) and local_layer_idx == len(layers_range) - 1:
continue
quant_cls = None
for cls in quant_cls_map:
if isinstance(layer, cls):
quant_cls = quant_cls_map[cls]
break
if quant_cls is None:
continue
is_excluded = False
for exclude_module in exclude_modules:
if fnmatch.fnmatchcase(name, exclude_module):
is_excluded = True
break
if is_excluded:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_config.quant_mode
if isinstance(layer, Attention):
init_params[
"num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
if parent is not None:
setattr(parent, module_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_config.quant_mode)
return model
# TODO: These functions should be moved to ModelOpt.
def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
s1_scales: torch.FloatTensor,
s2_scales: torch.FloatTensor,
s2_szeros: torch.FloatTensor,
group_size: int) -> torch.CharTensor:
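    """Two-stage QServe weight quantization: quantize fp16 weights to int8
    using the per-channel s1_scales, then to unsigned int4 using the per-group
    s2_scales and (pre-scaled) zero points s2_szeros. Returns the int4 values
    stored in an int8 tensor of shape [out_features, in_features]."""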
out_features = linear_weight.shape[0]
in_features = linear_weight.shape[1]
# Step 1: Quantize the weights to int8
linear_weight = linear_weight.div(
s1_scales.reshape(out_features, 1).to(linear_weight.device))
linear_weight = linear_weight.round()
# assert linear_weight.min() >= -119 and linear_weight.max() <= 119, "Stage 1: Quantized weight out of range" # 119 is the "magic" number
assert (linear_weight.min() >= -128 and linear_weight.max()
<= 127), "Stage 1: Quantized weight out of range"
# Step 2: Quantize the weights to int4
linear_weight = linear_weight.reshape(out_features,
in_features // group_size, group_size)
s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
1).to(torch.float16).to(linear_weight.device)
s2_scales = s2_scales.reshape(out_features, in_features // group_size,
1).to(torch.float16).to(linear_weight.device)
linear_weight = linear_weight.add(s2_szeros).div(s2_scales).round()
assert (linear_weight.min() >= 0 and linear_weight.max()
<= 15), "Stage 2: Quantized weight out of range"
qweight = linear_weight.reshape(out_features, in_features).to(torch.int8)
return qweight
def qserve_quantize_weight_per_channel(
linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
s1_szeros: torch.FloatTensor) -> torch.CharTensor:
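    """Quantize fp16 weights directly to unsigned int4 using the per-channel
    s1_scales and (pre-scaled) zero points s1_szeros; the int4 values are
    returned in an int8 tensor."""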
out_features = linear_weight.shape[0]
in_features = linear_weight.shape[1]
# Step 1: Quantize the weights to int4
s1_scales = s1_scales.reshape(out_features, 1).to(linear_weight.device)
s1_szeros = s1_szeros.reshape(out_features, 1).to(linear_weight.device)
qweight = linear_weight.add(s1_szeros).div(s1_scales).round()
assert (qweight.min() >= 0
and qweight.max() <= 15), "Quantized weight out of range"
return qweight.reshape(out_features, in_features).to(torch.int8)
# Pack the quantized weights, scales and zeros, and apply the reordering required by the QServe kernels.
# Returns the processed [qweight, s1_scales, s2_scales, s2_zeros].
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
s1_scales: torch.FloatTensor,
s2_scales: torch.FloatTensor,
s2_szeros: torch.FloatTensor, group_size):
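    """Pack the per-group int4 weights two per byte and reorder the weights,
    scales and zero points into the layout expected by the QServe W4A8
    per-group GEMM kernels."""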
out_features = qweight.shape[0]
in_features = qweight.shape[1]
outputs = []
s1_scales = s1_scales.reshape(out_features).to(torch.float16)
s2_szeros = s2_szeros.reshape(out_features,
in_features // group_size).to(torch.int8)
s2_scales = s2_scales.reshape(out_features,
in_features // group_size).to(torch.int8)
    # Step 3: Pack the int4 values into the final packed weight layout
# ---- Repack the weight ---- #
assert qweight.dtype == torch.int8
# pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
W_unpack_reorder = (qweight.reshape(
out_features // 32,
2,
2,
8,
in_features // 32,
2,
4,
4,
).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
4).contiguous().to(torch.int8))
# B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
# [16, 0, 17, 1, ...]
W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
0]
W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
in_features // 32, 32, 16)
W_unpack_repacked = W_unpack_repacked.reshape(out_features,
in_features // 2)
outputs.append(W_unpack_repacked)
    # For the last dimension, organize entries as 0, 8, 16, 24, 1, 9, 17, 25, ... as required by the tensor-core GEMM
# ---- Pack the scales ---- #
outputs.append(s1_scales.reshape(out_features))
s2_scales = (s2_scales.reshape(out_features, in_features //
group_size).transpose(0, 1).contiguous())
s2_scales = s2_scales.reshape(in_features // group_size, out_features // 32,
32)
s2_scales = (s2_scales.reshape(in_features // group_size,
out_features // 32, 4,
8).transpose(-2, -1).contiguous())
s2_scales = s2_scales.reshape(in_features // group_size,
out_features).contiguous()
outputs.append(s2_scales)
# ---- Pack the zeros ---- #
s2_szeros = (s2_szeros.reshape(out_features, in_features //
group_size).transpose(0, 1).contiguous())
s2_szeros = s2_szeros.reshape(in_features // group_size, out_features // 32,
32)
s2_szeros = (s2_szeros.reshape(in_features // group_size,
out_features // 32, 4,
8).transpose(-2, -1).contiguous())
s2_szeros = (s2_szeros.reshape(in_features // group_size,
out_features).contiguous())
    # (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
    # so we store -s2_zeros * s2_scales in place of s2_zeros.
s2_szeros = (-s2_szeros).int() # It has been pre-scaled in DeepCompressor
s2_szeros = s2_szeros.to(torch.int8)
outputs.append(s2_szeros)
return outputs
def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
s1_scales: torch.FloatTensor,
s1_szeros: torch.FloatTensor):
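    """Pack the per-channel int4 weights two per byte and reorder them, along
    with the per-channel scales and zero points, for the QServe W4A8
    per-channel GEMM kernels."""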
out_features = qweight.shape[0]
in_features = qweight.shape[1]
outputs = []
# ---- Repack the weight ---- #
assert qweight.dtype == torch.int8
# pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
W_unpack_reorder = (qweight.reshape(
out_features // 32,
2,
2,
8,
in_features // 32,
2,
4,
4,
).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
4).contiguous())
# B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
# [16, 0, 17, 1, ...]
W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
0]
W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
in_features // 32, 32, 16)
W_unpack_repacked = W_unpack_repacked.reshape(out_features, in_features //
2).contiguous()
outputs.append(W_unpack_repacked)
# ---- Pack the scales and zeros ---- #
s1_scales = s1_scales.reshape(out_features).contiguous()
outputs.append(s1_scales.half())
s1_szeros = s1_szeros.reshape(out_features).contiguous().half()
outputs.append(s1_szeros)
return outputs
# TODO: Duplicates smooth_quantize and quantize_layers
def qserve_quantize(model, quant_config: QuantConfig):
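    """Replace norm, MLP and attention modules with their QServe W4A8
    variants, leaving the final 'ln_f'/'ln_embed' norms untouched."""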
quant_mode = quant_config.quant_mode
assert quant_config.quant_mode.is_qserve_w4a8()
quant_map = {
RmsNorm: QServeRmsNorm,
LayerNorm: QServeRmsNorm,
GatedMLP: QServeGatedMLP,
MLP: QServeMLP,
Attention: QServeAttention,
}
for name, layer, parent in model.named_modules_with_parent():
layer_name = name.rsplit('.', 1)[-1]
if layer_name in ['ln_f', 'ln_embed']:
continue
quant_cls = None
for cls in quant_map:
if isinstance(layer, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_mode
if isinstance(layer, Attention):
init_params[
"num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, layer_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_mode)
return model
def fp4_quantize(model, quant_config: QuantConfig):
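    """NVFP4 quantization: replace linear layers with their FP4 counterparts;
    MoE modules are re-created with the FP4 quant mode."""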
assert quant_config.quant_mode.has_nvfp4()
quant_map = {
ColumnLinear: FP4Linear,
RowLinear: FP4RowLinear,
MixtureOfExperts: MixtureOfExperts,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model
# For now, assume KV-cache quantization is enabled for all layers.
def kv_cache_quantize(model):
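    """Attach KV-cache scaling factor parameters (and their reciprocals) to
    every attention module so the KV cache can be quantized and dequantized."""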
for name, module in model.named_modules():
if isinstance(module,
(Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
# for dequant
module.kv_cache_scaling_factor = Parameter(shape=(1, ),
dtype='float32')
# for quant
module.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
dtype='float32')
return model
def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
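    """Top-level entry point: pick the quantization routine matching each
    layer's quant mode (FP8, FP8 rowwise, QServe W4A8, NVFP4, SmoothQuant or
    weight-only) and apply it, honoring per-layer settings when
    `quant_config.quant_algo == QuantAlgo.MIXED_PRECISION`; KV-cache
    quantization parameters are attached at the end if requested.

    A minimal usage sketch (assumes an already-built TensorRT-LLM model object
    and an FP8 config; illustrative only, not taken from this file):

        from tensorrt_llm.models.modeling_utils import QuantConfig
        from tensorrt_llm.quantization.mode import QuantAlgo

        quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)
        model = quantize(model, quant_config)
    """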
for name, module, parent in model.named_modules_with_parent():
if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
layer_quant_mode = quant_config.layer_quant_mode(name)
else:
layer_quant_mode = quant_config.layer_quant_mode
if layer_quant_mode == QuantMode(0):
continue
layer_quant_cfg = quant_config._get_quant_cfg(name)
if layer_quant_mode.has_fp8_qdq():
module = fp8_quantize(module, layer_quant_cfg)
elif layer_quant_mode.has_fp8_rowwise():
module = fp8_rowwise_quantize(module, layer_quant_cfg)
elif layer_quant_mode.is_qserve_w4a8():
module = qserve_quantize(module, quant_config)
elif layer_quant_mode.has_nvfp4():
module = fp4_quantize(module, layer_quant_cfg)
elif layer_quant_mode.has_act_and_weight_quant():
module = smooth_quantize(module, layer_quant_cfg)
elif layer_quant_mode.is_weight_only():
if layer_quant_mode.has_per_group_scaling():
module = weight_only_groupwise_quantize(module, layer_quant_cfg,
model.config)
else:
module = weight_only_quantize(module, layer_quant_cfg,
model.config)
        if parent is not None:  # per-layer case: replace the sub-module in its parent
module_name = name.rsplit('.', 1)[-1]
setattr(parent, module_name, module)
        else:  # whole-model case: the quantized module is the model itself
model = module
break
if quant_config.quant_mode.has_kv_cache_quant():
model = kv_cache_quantize(model)
setattr(model, 'quant_mode', quant_config.quant_mode)
return model