import fnmatch
from typing import Union

import torch

from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
                      LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
from .layers import (FP4Linear, FP4RowLinear, FP8Linear, FP8RowLinear,
                     Fp8RowwiseAttention, Fp8RowwiseGatedMLP, Fp8RowwiseMLP,
                     Fp8RowwiseRmsNorm, Int8SmoothQuantLinear,
                     Int8SmoothQuantRowLinear, QServeAttention, QServeGatedMLP,
                     QServeMLP, QServeRmsNorm, SmoothQuantAttention,
                     SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
                     SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
                     WeightOnlyGroupwiseQuantRowLinear,
                     WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
                     WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode


def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
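    """Replace supported sub-modules of ``model`` with their quantized counterparts.

    ``quant_map`` maps base layer classes to quantized layer classes. Modules whose
    names match ``quant_config.exclude_modules`` (or the default exclusion patterns)
    are skipped, except MoE modules, which are always re-created so that their
    constructor sees the right quant mode (``QuantMode(0)`` when excluded).
    ``preprocess_init_params``, if given, may adjust the constructor arguments
    before each quantized layer is instantiated.
    """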
    exclude_modules = quant_config.exclude_modules or [
        '*lm_head',
        '*router',
        '*vocab_embedding',
        '*position_embedding',
        '*block_embedding',
        '*shared_expert_gate',
    ]

    for name, module, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]
        is_excluded = False
        quant_cls = None

        # handle exclusion
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break

        # MoE modules are quantized in their constructor, so they must always
        # be re-created with the appropriate quant_mode; when excluded, they are
        # re-created with quant_mode 0. This special case may go away if the
        # MoE implementation is redesigned.
        if isinstance(module, MixtureOfExperts):
            quant_cls = type(module)
        elif not is_excluded:
            for cls in quant_map:
                if isinstance(module, cls):
                    quant_cls = quant_map[cls]
                    break

        if quant_cls:
            init_params = get_init_params(module, quant_cls)
            if isinstance(module, MixtureOfExperts):
                if is_excluded:
                    quant_mode = QuantMode(0)
                else:
                    quant_mode = quant_config.quant_mode
                init_params["quant_mode"] = quant_mode
            if "bias" in init_params and not isinstance(module,
                                                        MixtureOfExperts):
                init_params["bias"] = init_params["bias"] is not None
            if isinstance(module, ColumnLinear):
                init_params[
                    "out_features"] = module.out_features * module.tp_size
            elif isinstance(module, RowLinear):
                init_params["in_features"] = module.in_features * module.tp_size
            if preprocess_init_params is not None:
                preprocess_init_params(init_params, name, module)
            quant_layer = quant_cls(**init_params)
            if parent is not None:
                setattr(parent, module_name, quant_layer)
            else:
                model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
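    """Apply weight-only quantization by swapping ColumnLinear, RowLinear and
    Embedding layers for their WeightOnlyQuant* counterparts."""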
    assert quant_config.quant_mode.is_weight_only()

    try:
        model_cfg = model.config
    except AttributeError:
        model_cfg = model_config

    quant_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            module_name = name.rsplit('.', 1)[-1]
            init_params["transb"] = module_name == "lm_head"
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def weight_only_groupwise_quantize(model,
                                   quant_config: QuantConfig,
                                   model_config=None):
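    """Apply groupwise weight-only quantization (AWQ/GPTQ-style) by swapping
    linear layers for their groupwise-quantized counterparts and forwarding
    group_size, zero-point and related settings from ``quant_config``."""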
    assert quant_config.quant_mode.is_weight_only()

    try:
        model_cfg = model.config
    except AttributeError:
        model_cfg = model_config

    quant_map = {
        ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
        RowLinear: WeightOnlyGroupwiseQuantRowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["group_size"] = quant_config.group_size
        init_params["pre_quant_scale"] = quant_config.pre_quant_scale
        init_params["zero"] = quant_config.has_zero_point
        init_params[
            "use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ
        init_params[
            "use_int8_weight"] = quant_config.quant_algo == QuantAlgo.W8A16_GPTQ
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
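    """SmoothQuant without plugins ("out of the box" path): swap ColumnLinear
    and RowLinear for the INT8 SmoothQuant linear layers."""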
    quant_map = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def smooth_quantize_plugin(model, quant_mode):
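    """SmoothQuant with plugins: replace norm, MLP and attention blocks with
    their SmoothQuant variants; final norms such as 'ln_f'/'ln_embed' are
    left untouched."""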
    quant_map = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_config: QuantConfig):
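    """Dispatch SmoothQuant to the plugin path or the OOTB path depending on
    whether ``quant_config.quant_algo`` is in W8A8_SQ_PLUGIN_LIST."""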
    assert quant_config.quant_mode.has_act_and_weight_quant()
    if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
        return smooth_quantize_plugin(model, quant_config.quant_mode)
    else:
        return smooth_quantize_ootb(model, quant_config)


def fp8_quantize(model, quant_config: QuantConfig):
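    """Apply FP8 quantize/dequantize (QDQ) quantization by swapping linear
    layers for FP8Linear/FP8RowLinear; MoE modules are re-created with the
    FP8 quant mode."""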
    assert quant_config.quant_mode.has_fp8_qdq()

    quant_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def fp8_rowwise_quantize(model, quant_config: QuantConfig):
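    """Apply FP8 rowwise quantization by replacing RmsNorm/GatedMLP/MLP/Attention
    with their Fp8Rowwise* variants. Honors ``exclude_modules`` and, with
    ``use_meta_recipe``, additionally skips attention and input_layernorm modules
    as well as the first and last transformer layers of the pipeline."""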
    assert quant_config.quant_mode.has_fp8_rowwise()

    quant_cls_map = {
        RmsNorm: Fp8RowwiseRmsNorm,
        GatedMLP: Fp8RowwiseGatedMLP,
        MLP: Fp8RowwiseMLP,
        Attention: Fp8RowwiseAttention,
    }

    if quant_config.exclude_modules is None:
        exclude_modules = ['*ln_f', '*ln_embed']
    else:
        exclude_modules = quant_config.exclude_modules

    def extract_layer_idx(name):
        ss = name.split('.')
        for s in ss:
            if s.isdigit():
                return int(s)
        return None

    # Meta's LLaMA 3.1 recipe:
    # (1) Skip quantization for the first and last Transformer layers
    # (2) Skip quantization for the Attention layers
    if quant_config.use_meta_recipe:
        exclude_modules.extend(['*input_layernorm', '*attention'])

    for name, layer, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]

        if quant_config.use_meta_recipe:
            local_layer_idx = extract_layer_idx(name)
            mapping = model.config.mapping
            layers_range = mapping.pp_layers(model.config.num_hidden_layers)
            if mapping.is_first_pp_rank() and local_layer_idx == 0:
                continue
            if (mapping.is_last_pp_rank()
                    and local_layer_idx == len(layers_range) - 1):
                continue

        quant_cls = None
        for cls in quant_cls_map:
            if isinstance(layer, cls):
                quant_cls = quant_cls_map[cls]
                break
        if quant_cls is None:
            continue

        is_excluded = False
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break
        if is_excluded:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
        if parent is not None:
            setattr(parent, module_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


# TODO: These functions should be moved to ModelOpt.
def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
                                     s1_scales: torch.FloatTensor,
                                     s2_scales: torch.FloatTensor,
                                     s2_szeros: torch.FloatTensor,
                                     group_size: int) -> torch.CharTensor:
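    """Fake-quantize an fp16 weight into QServe's two-stage per-group INT4 form.

    Stage 1 divides by the per-channel scales ``s1_scales`` (values must land in
    the int8 range); stage 2 applies per-group zeros ``s2_szeros`` and scales
    ``s2_scales`` over groups of ``group_size`` so the result fits in [0, 15].
    Returns an int8 tensor of shape [out_features, in_features] holding the
    unpacked 4-bit values."""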
    out_features = linear_weight.shape[0]
    in_features = linear_weight.shape[1]

    # Step 1: Quantize the weights to int8
    linear_weight = linear_weight.div(
        s1_scales.reshape(out_features, 1).to(linear_weight.device))
    linear_weight = linear_weight.round()
    # assert linear_weight.min() >= -119 and linear_weight.max() <= 119, "Stage 1: Quantized weight out of range"  # 119 is the "magic" number
    assert (linear_weight.min() >= -128 and linear_weight.max()
            <= 127), "Stage 1: Quantized weight out of range"

    # Step 2: Quantize the weights to int4
    linear_weight = linear_weight.reshape(out_features,
                                          in_features // group_size, group_size)
    s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
                                  1).to(torch.float16).to(linear_weight.device)
    s2_scales = s2_scales.reshape(out_features, in_features // group_size,
                                  1).to(torch.float16).to(linear_weight.device)
    linear_weight = linear_weight.add(s2_szeros).div(s2_scales).round()
    assert (linear_weight.min() >= 0 and linear_weight.max()
            <= 15), "Stage 2: Quantized weight out of range"

    qweight = linear_weight.reshape(out_features, in_features).to(torch.int8)
    return qweight


def qserve_quantize_weight_per_channel(
        linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
        s1_szeros: torch.FloatTensor) -> torch.CharTensor:
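    """Quantize an fp16 weight into QServe's per-channel INT4 form using the
    per-channel scales ``s1_scales`` and zeros ``s1_szeros``. Returns an int8
    tensor of shape [out_features, in_features] with values in [0, 15]."""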
    out_features = linear_weight.shape[0]
    in_features = linear_weight.shape[1]

    # Step 1: Quantize the weights to int4
    s1_scales = s1_scales.reshape(out_features, 1).to(linear_weight.device)
    s1_szeros = s1_szeros.reshape(out_features, 1).to(linear_weight.device)

    qweight = linear_weight.add(s1_szeros).div(s1_scales).round()
    assert (qweight.min() >= 0
            and qweight.max() <= 15), "Quantized weight out of range"

    return qweight.reshape(out_features, in_features).to(torch.int8)


# Pack the quantized weights, scales and zeros and apply the reordering
# required by the QServe kernels.
# Return: processed [qweight, s1_scales, s2_scales, s2_szeros]
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
                                  s1_scales: torch.FloatTensor,
                                  s2_scales: torch.FloatTensor,
                                  s2_szeros: torch.FloatTensor, group_size):
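    """Pack and reorder per-group QServe tensors for the W4A8 GEMM kernels.

    Returns [packed_qweight, s1_scales, s2_scales, s2_szeros]: the weight packed
    two 4-bit values per int8 byte (shape [out_features, in_features // 2]),
    fp16 per-channel scales of shape [out_features], and per-group scales/zeros
    as int8 tensors reordered to shape [in_features // group_size, out_features]."""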
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    s1_scales = s1_scales.reshape(out_features).to(torch.float16)
    s2_szeros = s2_szeros.reshape(out_features,
                                  in_features // group_size).to(torch.int8)
    s2_scales = s2_scales.reshape(out_features,
                                  in_features // group_size).to(torch.int8)

    # Step 3: Pack the quantized weights to real quantized weights
    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous().to(torch.int8))
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
                                                                           0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features,
                                                  in_features // 2)

    outputs.append(W_unpack_repacked)

    # for the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ... following the requirement of tensor core gemm
    # ---- Pack the scales ---- #
    outputs.append(s1_scales.reshape(out_features))

    s2_scales = (s2_scales.reshape(out_features, in_features //
                                   group_size).transpose(0, 1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size, out_features // 32,
                                  32)
    s2_scales = (s2_scales.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size,
                                  out_features).contiguous()
    outputs.append(s2_scales)

    # ---- Pack the zeros ---- #
    s2_szeros = (s2_szeros.reshape(out_features, in_features //
                                   group_size).transpose(0, 1).contiguous())
    s2_szeros = s2_szeros.reshape(in_features // group_size, out_features // 32,
                                  32)
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features).contiguous())

    # (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
    # We convert the s2_zeros -> -s2_zeros * s2_scales
    s2_szeros = (-s2_szeros).int()  # It has been pre-scaled in DeepCompressor
    s2_szeros = s2_szeros.to(torch.int8)

    outputs.append(s2_szeros)

    return outputs


def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
                                    s1_scales: torch.FloatTensor,
                                    s1_szeros: torch.FloatTensor):
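    """Pack and reorder per-channel QServe tensors.

    Returns [packed_qweight, s1_scales, s1_szeros]: the weight packed two 4-bit
    values per int8 byte ([out_features, in_features // 2]) plus fp16 per-channel
    scales and zeros, each of shape [out_features]."""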
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous())
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
                                                                           0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features, in_features //
                                                  2).contiguous()

    outputs.append(W_unpack_repacked)

    # ---- Pack the scales and zeros ---- #
    s1_scales = s1_scales.reshape(out_features).contiguous()
    outputs.append(s1_scales.half())

    s1_szeros = s1_szeros.reshape(out_features).contiguous().half()
    outputs.append(s1_szeros)

    return outputs


# TODO: Duplicates smooth_quantize and quantize_layers
def qserve_quantize(model, quant_config: QuantConfig):
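    """Replace norm, MLP and attention blocks with their QServe (W4A8) variants,
    mirroring smooth_quantize_plugin; final norms such as 'ln_f'/'ln_embed'
    are skipped."""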
    quant_mode = quant_config.quant_mode
    assert quant_config.quant_mode.is_qserve_w4a8()

    quant_map = {
        RmsNorm: QServeRmsNorm,
        LayerNorm: QServeRmsNorm,
        GatedMLP: QServeGatedMLP,
        MLP: QServeMLP,
        Attention: QServeAttention,
    }

    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model


def fp4_quantize(model, quant_config: QuantConfig):
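    """Apply NVFP4 quantization by swapping linear layers for
    FP4Linear/FP4RowLinear; MoE modules are re-created with the NVFP4 quant
    mode."""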
    assert quant_config.quant_mode.has_nvfp4()
    quant_map = {
        ColumnLinear: FP4Linear,
        RowLinear: FP4RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


# For now, assume the KV cache is quantized for all layers.
def kv_cache_quantize(model):
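    """Attach KV-cache scaling-factor parameters (scale and reciprocal scale)
    to every attention module so the KV cache can be quantized and
    dequantized."""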
    for name, module in model.named_modules():
        if isinstance(module,
                      (Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
            # for dequant
            module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
            # for quant
            module.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
                                                           dtype='float32')
    return model


def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
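    """Entry point: quantize ``model`` according to ``quant_config``.

    For mixed-precision configs the quant mode is resolved per layer; otherwise
    a single mode is applied to the whole model. Dispatches to the FP8,
    FP8-rowwise, QServe, NVFP4, SmoothQuant and weight-only helpers above, then
    optionally adds KV-cache scaling factors.

    Illustrative usage (assumes a ``QuantConfig`` built elsewhere, e.g. for a
    weight-only or FP8 algorithm)::

        model = quantize(model, quant_config)
    """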

    for name, module, parent in model.named_modules_with_parent():

        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            layer_quant_mode = quant_config.layer_quant_mode
        if layer_quant_mode == QuantMode(0):
            continue

        layer_quant_cfg = quant_config._get_quant_cfg(name)

        if layer_quant_mode.has_fp8_qdq():
            module = fp8_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_fp8_rowwise():
            module = fp8_rowwise_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_qserve_w4a8():
            module = qserve_quantize(module, quant_config)
        elif layer_quant_mode.has_nvfp4():
            module = fp4_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_act_and_weight_quant():
            module = smooth_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                module = weight_only_groupwise_quantize(module, layer_quant_cfg,
                                                        model.config)
            else:
                module = weight_only_quantize(module, layer_quant_cfg,
                                              model.config)

        if parent is not None:  # per-layer quantization
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:  # quantization applied to the whole model
            model = module
            break

    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model