TensorRT-LLM/tensorrt_llm/quantization/quantize.py
import fnmatch
from typing import Union

import torch

from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
                      LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
from .layers import (FP4Linear, FP4RowLinear, FP8Linear, FP8RowLinear,
                     Fp8RowwiseAttention, Fp8RowwiseGatedMLP, Fp8RowwiseMLP,
                     Fp8RowwiseRmsNorm, Int8SmoothQuantLinear,
                     Int8SmoothQuantRowLinear, QServeAttention, QServeGatedMLP,
                     QServeMLP, QServeRmsNorm, SmoothQuantAttention,
                     SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
                     SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
                     WeightOnlyGroupwiseQuantRowLinear,
                     WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
                     WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode

def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
    exclude_modules = quant_config.exclude_modules
    if exclude_modules is None:
        exclude_modules = [
            '*lm_head',
            '*router',
            '*vocab_embedding',
            '*position_embedding',
            '*block_embedding',
            '*shared_expert_gate',
        ]

    for name, module, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]
        is_excluded = False
        quant_cls = None

        # Handle exclusion patterns.
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break

        # MoE modules are quantized in their constructor, so they must always
        # be re-created with the appropriate quant_mode. When excluded,
        # re-create them with quant_mode 0.
        # This special case may go away if the MoE implementation is redesigned.
        if isinstance(module, MixtureOfExperts):
            quant_cls = type(module)
        elif not is_excluded:
            for cls in quant_map:
                if isinstance(module, cls):
                    quant_cls = quant_map[cls]
                    break

        if quant_cls:
            init_params = get_init_params(module, quant_cls)

            if isinstance(module, MixtureOfExperts):
                if is_excluded:
                    quant_mode = QuantMode(0)
                else:
                    quant_mode = quant_config.quant_mode
                init_params["quant_mode"] = quant_mode

            if "bias" in init_params and not isinstance(module,
                                                        MixtureOfExperts):
                init_params["bias"] = init_params["bias"] is not None
            if isinstance(module, ColumnLinear):
                init_params[
                    "out_features"] = module.out_features * module.tp_size
            elif isinstance(module, RowLinear):
                init_params["in_features"] = module.in_features * module.tp_size
            if preprocess_init_params is not None:
                preprocess_init_params(init_params, name, module)
            quant_layer = quant_cls(**init_params)
            if parent is not None:
                setattr(parent, module_name, quant_layer)
            else:
                model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
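
# Illustrative sketch (not part of the library API; module names below are
# hypothetical): `quantize_layers` swaps layers in place according to
# `quant_map`, while the fnmatch-style patterns in `exclude_modules` skip
# matching module names. For example, given a model with a
# 'transformer.layers.0.mlp.fc' ColumnLinear and an 'lm_head' ColumnLinear,
#
#     quantize_layers(model, quant_config,
#                     {ColumnLinear: FP8Linear, RowLinear: FP8RowLinear})
#
# would replace the MLP projection with an FP8Linear but leave 'lm_head'
# untouched, because the default exclusion list contains the pattern
# '*lm_head'.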

def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
    assert quant_config.quant_mode.is_weight_only()

    try:
        model_cfg = model.config
    except AttributeError:
        model_cfg = model_config

    quant_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            module_name = name.rsplit('.', 1)[-1]
            init_params["transb"] = module_name == "lm_head"
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model

def weight_only_groupwise_quantize(model,
                                   quant_config: QuantConfig,
                                   model_config=None):
    assert quant_config.quant_mode.is_weight_only()

    try:
        model_cfg = model.config
    except AttributeError:
        model_cfg = model_config

    quant_map = {
        ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
        RowLinear: WeightOnlyGroupwiseQuantRowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["group_size"] = quant_config.group_size
        init_params["pre_quant_scale"] = quant_config.pre_quant_scale
        init_params["zero"] = quant_config.has_zero_point
        init_params[
            "use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ
        init_params[
            "use_int8_weight"] = quant_config.quant_algo == QuantAlgo.W8A16_GPTQ
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model

def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
    quant_map = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model

def smooth_quantize_plugin(model, quant_mode):
    quant_map = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model

def smooth_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_act_and_weight_quant()
    if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
        return smooth_quantize_plugin(model, quant_config.quant_mode)
    else:
        return smooth_quantize_ootb(model, quant_config)

def fp8_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_fp8_qdq()

    quant_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model

def fp8_rowwise_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_fp8_rowwise()

    quant_cls_map = {
        RmsNorm: Fp8RowwiseRmsNorm,
        GatedMLP: Fp8RowwiseGatedMLP,
        MLP: Fp8RowwiseMLP,
        Attention: Fp8RowwiseAttention,
    }

    exclude_modules = quant_config.exclude_modules
    if exclude_modules is None:
        exclude_modules = []
    # Always exclude these modules for FP8 rowwise.
    exclude_modules = list(
        set(exclude_modules + ['*ln_f', '*ln_embed', '*lm_head']))

    def extract_layer_idx(name):
        ss = name.split('.')
        for s in ss:
            if s.isdigit():
                return int(s)
        return None

    # Meta's LLaMA 3.1 recipe:
    # (1) Skip quantization for the first and last Transformer layers
    # (2) Skip quantization for the Attention layers
    if quant_config.use_meta_recipe:
        exclude_modules.extend(['*input_layernorm', '*attention'])

    for name, layer, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]

        if quant_config.use_meta_recipe:
            local_layer_idx = extract_layer_idx(name)
            mapping = model.config.mapping
            layers_range = mapping.pp_layers(model.config.num_hidden_layers)
            if mapping.is_first_pp_rank() and local_layer_idx == 0:
                continue
            if mapping.is_last_pp_rank(
            ) and local_layer_idx == len(layers_range) - 1:
                continue

        quant_cls = None
        for cls in quant_cls_map:
            if isinstance(layer, cls):
                quant_cls = quant_cls_map[cls]
                break
        if quant_cls is None:
            continue

        is_excluded = False
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break
        if is_excluded:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
        if parent is not None:
            setattr(parent, module_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
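
# Illustrative sketch of the Meta recipe path (module names are hypothetical):
# with use_meta_recipe enabled, a pattern such as '*attention' excludes e.g.
# 'transformer.layers.5.attention', and extract_layer_idx() returns the first
# numeric path component, so
#     extract_layer_idx('transformer.layers.12.mlp')  -> 12
#     extract_layer_idx('lm_head')                    -> None
# which lets the loop skip the first layer on the first PP rank and the last
# layer on the last PP rank.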

# TODO: These functions should be moved to ModelOpt.
def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
                                     s1_scales: torch.FloatTensor,
                                     s2_scales: torch.FloatTensor,
                                     s2_szeros: torch.FloatTensor,
                                     group_size: int) -> torch.CharTensor:
    out_features = linear_weight.shape[0]
    in_features = linear_weight.shape[1]

    # Step 1: Quantize the weights to int8.
    linear_weight = linear_weight.div(
        s1_scales.reshape(out_features, 1).to(linear_weight.device))
    linear_weight = linear_weight.round()
    # assert linear_weight.min() >= -119 and linear_weight.max() <= 119, "Stage 1: Quantized weight out of range"  # 119 is the "magic" number
    assert (linear_weight.min() >= -128 and linear_weight.max()
            <= 127), "Stage 1: Quantized weight out of range"

    # Step 2: Quantize the weights to int4.
    linear_weight = linear_weight.reshape(out_features,
                                          in_features // group_size,
                                          group_size)
    s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
                                  1).to(torch.float16).to(linear_weight.device)
    s2_scales = s2_scales.reshape(out_features, in_features // group_size,
                                  1).to(torch.float16).to(linear_weight.device)
    linear_weight = linear_weight.add(s2_szeros).div(s2_scales).round()
    assert (linear_weight.min() >= 0 and linear_weight.max()
            <= 15), "Stage 2: Quantized weight out of range"
    qweight = linear_weight.reshape(out_features, in_features).to(torch.int8)

    return qweight
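
# Illustrative worked example (values are hypothetical, not from the library):
# with a per-channel scale s1 = 0.5, a group scale s2 = 2 and a group zero
# s2_szero = 16, a half-precision weight w = 7.0 is quantized as
#     stage 1: round(7.0 / 0.5)     = 14   (int8, in [-128, 127])
#     stage 2: round((14 + 16) / 2) = 15   (uint4, in [0, 15])
# so the stored 4-bit value encodes the weight relative to both scales and the
# per-group zero point.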

def qserve_quantize_weight_per_channel(
        linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
        s1_szeros: torch.FloatTensor) -> torch.CharTensor:
    out_features = linear_weight.shape[0]
    in_features = linear_weight.shape[1]

    # Step 1: Quantize the weights to int4.
    s1_scales = s1_scales.reshape(out_features, 1).to(linear_weight.device)
    s1_szeros = s1_szeros.reshape(out_features, 1).to(linear_weight.device)
    qweight = linear_weight.add(s1_szeros).div(s1_scales).round()
    assert (qweight.min() >= 0
            and qweight.max() <= 15), "Quantized weight out of range"

    return qweight.reshape(out_features, in_features).to(torch.int8)

# Pack the quantized weights, scales and zeros and apply the reordering
# required by the QServe kernels.
# Return: processed [qweight, s1_scales, s2_scales, s2_zeros]
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
                                  s1_scales: torch.FloatTensor,
                                  s2_scales: torch.FloatTensor,
                                  s2_szeros: torch.FloatTensor, group_size):
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    s1_scales = s1_scales.reshape(out_features).to(torch.float16)
    s2_szeros = s2_szeros.reshape(out_features,
                                  in_features // group_size).to(torch.int8)
    s2_scales = s2_scales.reshape(out_features,
                                  in_features // group_size).to(torch.int8)

    # Step 3: Pack the quantized weights into the real quantized layout.
    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # Pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4).
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous().to(torch.int8))
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
                                                                           0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features,
                                                  in_features // 2)
    outputs.append(W_unpack_repacked)

    # For the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ...
    # following the requirement of the tensor core GEMM.
    # ---- Pack the scales ---- #
    outputs.append(s1_scales.reshape(out_features))

    s2_scales = (s2_scales.reshape(out_features, in_features //
                                   group_size).transpose(0, 1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size, out_features // 32,
                                  32)
    s2_scales = (s2_scales.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size,
                                  out_features).contiguous()
    outputs.append(s2_scales)

    # ---- Pack the zeros ---- #
    s2_szeros = (s2_szeros.reshape(out_features, in_features //
                                   group_size).transpose(0, 1).contiguous())
    s2_szeros = s2_szeros.reshape(in_features // group_size, out_features // 32,
                                  32)
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features).contiguous())

    # (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
    # so we convert s2_zeros -> -s2_zeros * s2_scales.
    s2_szeros = (-s2_szeros).int()  # It has been pre-scaled in DeepCompressor.
    s2_szeros = s2_szeros.to(torch.int8)
    outputs.append(s2_szeros)

    return outputs
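
# Illustrative shape sketch (hypothetical sizes, not from the library): for
# out_features = 64, in_features = 128 and group_size = 64, the returned list
# holds
#     qweight  : (64, 64)  int8, two 4-bit values packed per byte
#     s1_scales: (64,)     float16, one per output channel
#     s2_scales: (2, 64)   int8, one row per group of 64 input channels
#     s2_szeros: (2, 64)   int8, stored as -zero * scale for the fused kernel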

def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
                                    s1_scales: torch.FloatTensor,
                                    s1_szeros: torch.FloatTensor):
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # Pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4).
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous())
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
                                                                           0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features, in_features //
                                                  2).contiguous()
    outputs.append(W_unpack_repacked)

    # ---- Pack the scales and zeros ---- #
    s1_scales = s1_scales.reshape(out_features).contiguous()
    outputs.append(s1_scales.half())
    s1_szeros = s1_szeros.reshape(out_features).contiguous().half()
    outputs.append(s1_szeros)

    return outputs

# TODO: Duplicates smooth_quantize and quantize_layers.
def qserve_quantize(model, quant_config: QuantConfig):
    quant_mode = quant_config.quant_mode
    assert quant_config.quant_mode.is_qserve_w4a8()

    quant_map = {
        RmsNorm: QServeRmsNorm,
        LayerNorm: QServeRmsNorm,
        GatedMLP: QServeGatedMLP,
        MLP: QServeMLP,
        Attention: QServeAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model

def fp4_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_nvfp4()

    quant_map = {
        ColumnLinear: FP4Linear,
        RowLinear: FP4RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model

# For now, assume the KV cache is quantized for all layers.
def kv_cache_quantize(model):
    for name, module in model.named_modules():
        if isinstance(module,
                      (Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
            # Scaling factor used for dequantization.
            module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
            # Reciprocal scaling factor used for quantization.
            module.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
                                                           dtype='float32')
    return model
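
# Note (illustrative, values hypothetical): the two parameters are expected to
# be reciprocal, e.g. with a calibrated KV-cache scale of 0.02,
#     kv_cache_scaling_factor     = 0.02   # dequant: fp = int * scale
#     kv_cache_rcp_scaling_factor = 50.0   # quant:   int = fp * rcp_scale
# The concrete values are filled in later from the checkpoint / calibration
# data; this function only declares the parameters.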

def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    for name, module, parent in model.named_modules_with_parent():
        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            layer_quant_mode = quant_config.layer_quant_mode
        if layer_quant_mode == QuantMode(0):
            continue

        layer_quant_cfg = quant_config._get_quant_cfg(name)

        if layer_quant_mode.has_fp8_qdq():
            module = fp8_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_fp8_rowwise():
            module = fp8_rowwise_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_qserve_w4a8():
            module = qserve_quantize(module, quant_config)
        elif layer_quant_mode.has_nvfp4():
            module = fp4_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_act_and_weight_quant():
            module = smooth_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                module = weight_only_groupwise_quantize(module, layer_quant_cfg,
                                                        model.config)
            else:
                module = weight_only_quantize(module, layer_quant_cfg,
                                              model.config)

        if parent is not None:  # per-layer quantization
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:  # whole-model quantization
            model = module
            break

    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
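
# Usage sketch (assumptions: a TensorRT-LLM model object that provides
# named_modules_with_parent(), and a plain FP8 QuantConfig; checkpoint loading
# is omitted):
#
#     from tensorrt_llm.models.modeling_utils import QuantConfig
#     from tensorrt_llm.quantization.mode import QuantAlgo
#
#     quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
#                                kv_cache_quant_algo=QuantAlgo.FP8)
#     model = quantize(model, quant_config)
#
# Because the root module has no parent, `quantize()` applies `fp8_quantize`
# to the whole model in a single pass and then adds the KV-cache scaling
# parameters via `kv_cache_quantize`.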