# TensorRT-LLM: examples/quantization/quantize.py
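# The parser below passes ``__doc__`` as its description, so a module
# docstring belongs here. The wording and the example command are an
# illustrative sketch; the paths and model names are placeholders, not
# taken from the upstream file.
"""Quantize a HuggingFace or NeMo checkpoint and export it as a TensorRT-LLM checkpoint.

Example (hypothetical paths):
    python quantize.py --model_dir ./llama-3-8b-hf --qformat fp8 \
        --kv_cache_dtype fp8 --output_dir ./llama-3-8b-fp8
"""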

import argparse

import torch.multiprocessing as mp

from tensorrt_llm.quantization import (quantize_and_export,
                                       quantize_nemo_and_export)
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_dir",
                        help="Specify where the HuggingFace model is.",
                        default=None)
    parser.add_argument('--nemo_ckpt_path',
                        help="Specify where the NeMo checkpoint is.",
                        default=None)
    parser.add_argument(
        '--decoder_type',
        type=str,
        default='gptnext',
        choices=['gptnext', 'llama'],
        help="Decoder type; effective for NeMo checkpoint only.")
    parser.add_argument(
        '--device',
        help="The device to run calibration on; effective for HuggingFace "
        "model only.",
        default='cuda',
        choices=['cuda', 'cpu'])
    parser.add_argument(
        "--device_map",
        help="How to map the model onto the devices.",
        default="auto",
        choices=["auto", "sequential", "cpu", "gpu"],
    )
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='cnn_dailymail',
        help="The HuggingFace dataset name or the local directory of the "
        "dataset used for calibration.")
    parser.add_argument(
        '--calib_tp_size',
        type=int,
        default=1,
        help="Tensor parallel size for calibration; effective for NeMo "
        "checkpoint only.")
    parser.add_argument(
        '--calib_pp_size',
        type=int,
        default=1,
        help="Pipeline parallel size for calibration; effective for NeMo "
        "checkpoint only.")
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help="The data type for the weights and activations of the "
        "non-quantized parts, e.g., embedding and lm_head. "
        "If 'auto', the data type is inferred from the source model; "
        "however, if the source dtype is float32, it is converted to float16.")
    parser.add_argument(
        "--qformat",
        help="Quantization format.",
        default="full_prec",
        choices=[
            "nvfp4",
            "fp8",
            "fp8_pc_pt",
            "int8_sq",
            "int4_awq",
            "w4a8_awq",
            "int8_wo",
            "int4_wo",
            "full_prec",
        ],
    )
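    # Note (editorial): "fp8_pc_pt" selects the FP8 rowwise scheme
    # (per-channel weight scales, per-token activation scales) rather than
    # the single per-tensor scale used by plain "fp8".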
    parser.add_argument(
        "--seed",
        help="Seed for the random number generators; the value is passed to "
        "random.seed(value) and numpy.random.seed(value).",
        type=int,
        default=1234)
    parser.add_argument("--tokenizer_max_seq_length",
                        help="Max sequence length to initialize the tokenizers.",
                        type=int,
                        default=2048)
    parser.add_argument("--batch_size",
                        help="Batch size for calibration.",
                        type=int,
                        default=1)
    parser.add_argument("--calib_size",
                        help="Number of samples for calibration.",
                        type=int,
                        default=512)
    parser.add_argument("--calib_max_seq_length",
                        help="Max sequence length for calibration.",
                        type=int,
                        default=512)
    parser.add_argument("--output_dir", default="exported_model")
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--cp_size", type=int, default=1)
    parser.add_argument("--awq_block_size", type=int, default=128)
    parser.add_argument("--kv_cache_dtype",
                        help="KV cache dtype.",
                        default=None,
                        choices=["int8", "fp8", None])
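    # With --kv_cache_dtype fp8 (or int8) the KV cache is stored in 8 bits,
    # halving its memory footprint relative to fp16/bf16.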
parser.add_argument("--quantize_lm_head",
action='store_true',
default=False)
# Medusa
parser.add_argument('--num_medusa_heads', type=int, default=4)
parser.add_argument('--num_medusa_layers', type=int, default=1)
parser.add_argument('--max_draft_len', type=int, default=63)
parser.add_argument('--medusa_hidden_act', type=str, default="silu")
parser.add_argument('--medusa_model_dir', type=str, default=None)
parser.add_argument('--quant_medusa_head',
default=False,
action='store_true',
help="whether to quantize the weights of medusa heads")
    # Auto quantization
    parser.add_argument(
        '--autoq_format',
        default=None,
        type=str,
        help="Quantization algorithms to search over during auto "
        "quantization. Each algorithm must be in "
        "['fp8', 'int4_awq', 'w4a8_awq', 'int8_sq']. Use ',' to separate "
        "more than one algorithm (e.g., --autoq_format fp8,int4_awq,w4a8_awq). "
        "Note: fp8 and int8_sq can't be used at the same time.")
    parser.add_argument(
        '--auto_quantize_bits',
        type=float,
        default=None,
        help="Effective bits constraint for auto quantization. If not set, "
        "regular quantization without the auto quantization search is "
        "applied. It can't be set lower than the num_bits of the most "
        "aggressive quantization format; for example, if 'int4_awq' is in "
        "autoq_format, it can't be lower than 4.0.")

    args = parser.parse_args()

    # auto_quantize_bits sanity check
    if args.autoq_format:
        lower_bound, upper_bound = 4 if '4' in args.autoq_format else 8, 16
        if args.auto_quantize_bits is None or args.auto_quantize_bits < lower_bound or args.auto_quantize_bits > upper_bound:
            print(
                f"Invalid auto_quantize_bits value; falling back to {lower_bound}."
            )
            args.auto_quantize_bits = lower_bound
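    # Example (illustrative): with --autoq_format fp8,int4_awq the lower
    # bound is 4 because a 4-bit format is present, so
    # --auto_quantize_bits 3.5 would be reset to 4.0 while 6.0 is accepted.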
    if args.model_dir is not None:
        quantize_and_export(
            model_dir=args.model_dir,
            device=args.device,
            calib_dataset=args.calib_dataset,
            dtype=args.dtype,
            qformat=args.qformat
            if args.auto_quantize_bits is None else args.autoq_format,
            kv_cache_dtype=args.kv_cache_dtype,
            calib_size=args.calib_size,
            batch_size=args.batch_size,
            calib_max_seq_length=args.calib_max_seq_length,
            awq_block_size=args.awq_block_size,
            output_dir=args.output_dir,
            tp_size=args.tp_size,
            pp_size=args.pp_size,
            cp_size=args.cp_size,
            seed=args.seed,
            tokenizer_max_seq_length=args.tokenizer_max_seq_length,
            num_medusa_heads=args.num_medusa_heads,
            num_medusa_layers=args.num_medusa_layers,
            max_draft_len=args.max_draft_len,
            medusa_hidden_act=args.medusa_hidden_act,
            medusa_model_dir=args.medusa_model_dir,
            quant_medusa_head=args.quant_medusa_head,
            auto_quantize_bits=args.auto_quantize_bits,
            device_map=args.device_map,
            quantize_lm_head=args.quantize_lm_head)
    elif args.nemo_ckpt_path is not None:
        quantize_nemo_and_export(nemo_ckpt_path=args.nemo_ckpt_path,
                                 decoder_type=args.decoder_type,
                                 calib_dataset=args.calib_dataset,
                                 calib_tp_size=args.calib_tp_size,
                                 calib_pp_size=args.calib_pp_size,
                                 dtype=args.dtype,
                                 qformat=args.qformat,
                                 kv_cache_dtype=args.kv_cache_dtype,
                                 calib_size=args.calib_size,
                                 batch_size=args.batch_size,
                                 calib_max_seq_length=args.calib_max_seq_length,
                                 awq_block_size=args.awq_block_size,
                                 output_dir=args.output_dir,
                                 tp_size=args.tp_size,
                                 pp_size=args.pp_size,
                                 cp_size=args.cp_size,
                                 seed=args.seed)
    else:
        raise ValueError(
            "One of the source checkpoints (--model_dir or --nemo_ckpt_path) "
            "must be specified.")
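# Example invocations (all paths are placeholders), using only the flags
# defined above:
#
#   NeMo checkpoint with a llama decoder, INT4-AWQ:
#     python quantize.py --nemo_ckpt_path <ckpt_path> --decoder_type llama \
#         --qformat int4_awq --output_dir <output_dir>
#
#   Auto quantization searching over fp8 and int4_awq at ~6 effective bits:
#     python quantize.py --model_dir <hf_model_dir> \
#         --autoq_format fp8,int4_awq --auto_quantize_bits 6.0 \
#         --output_dir <output_dir>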