# TensorRT-LLM examples/quantization/quantize.py

import argparse

import torch.multiprocessing as mp

from tensorrt_llm.quantization import (quantize_and_export,
                                       quantize_nemo_and_export)

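# CUDA cannot be re-initialized in a forked subprocess, so force the "spawn"
# start method for any multiprocessing used during calibration.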
mp.set_start_method("spawn", force=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_dir",
                        help="Specify where the HuggingFace model is",
                        default=None)
    parser.add_argument('--nemo_ckpt_path',
                        help="Specify where the NeMo checkpoint is",
                        default=None)
    parser.add_argument(
        '--decoder_type',
        type=str,
        default='gptnext',
        choices=['gptnext', 'llama'],
        help="Decoder type; effective for NeMo checkpoint only.")
    parser.add_argument(
        '--device',
        help="The device on which to run calibration; effective for "
        "HuggingFace models only.",
        default='cuda',
        choices=['cuda', 'cpu'])
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='cnn_dailymail',
        help="The Hugging Face dataset name or the local directory of the "
        "dataset used for calibration.")
    parser.add_argument(
        '--calib_tp_size',
        type=int,
        default=1,
        help="Tensor parallel size for calibration; effective for NeMo "
        "checkpoint only.")
    parser.add_argument(
        '--calib_pp_size',
        type=int,
        default=1,
        help="Pipeline parallel size for calibration; effective for NeMo "
        "checkpoint only.")
    parser.add_argument("--dtype", help="Model data type.", default="float16")
    parser.add_argument(
        "--qformat",
        help="Quantization format.",
        default="full_prec",
        choices=[
            "fp8",
            "int8_sq",
            "int4_awq",
            "w4a8_awq",
            "int8_wo",
            "int4_wo",
            "full_prec",
        ],
    )
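    # Format glossary: "int8_sq" is SmoothQuant, "int4_awq"/"w4a8_awq" are AWQ
    # variants (4-bit weights; w4a8 also quantizes activations to 8 bits),
    # "int8_wo"/"int4_wo" are weight-only quantization, and "full_prec" leaves
    # the weights unquantized.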
    parser.add_argument(
        "--seed",
        help="Seed for random number generation; the value is used to call "
        "random.seed(value) and numpy.random.seed(value).",
        type=int,
        default=1234)
    parser.add_argument("--tokenizer_max_seq_length",
                        help="Max sequence length to initialize the tokenizer.",
                        type=int,
                        default=2048)
parser.add_argument("--batch_size",
help="Batch size for calibration.",
type=int,
default=1)
parser.add_argument("--calib_size",
help="Number of samples for calibration.",
type=int,
default=512)
parser.add_argument("--calib_max_seq_length",
help="Max sequence length for calibration",
type=int,
default=512)
parser.add_argument("--output_dir", default="exported_model")
parser.add_argument("--tp_size", type=int, default=1)
parser.add_argument("--pp_size", type=int, default=1)
parser.add_argument("--awq_block_size", type=int, default=128)
parser.add_argument("--kv_cache_dtype",
help="KV Cache dtype.",
default=None,
choices=["int8", "fp8", None])
    # Medusa (speculative decoding with extra draft heads)
    parser.add_argument('--num_medusa_heads', type=int, default=4)
    parser.add_argument('--num_medusa_layers', type=int, default=1)
    parser.add_argument('--max_draft_len', type=int, default=63)
    parser.add_argument('--medusa_hidden_act', type=str, default="silu")
    parser.add_argument('--medusa_model_dir', type=str, default=None)
    parser.add_argument('--quant_medusa_head',
                        default=False,
                        action='store_true',
                        help="Whether to quantize the weights of Medusa heads.")
    # Auto quantization
    parser.add_argument(
        '--autoq_format',
        default=None,
        type=str,
        help="Quantization algorithms to search over during auto "
        "quantization. Each algorithm must be one of ['fp8', 'int4_awq', "
        "'w4a8_awq', 'int8_sq']. Use ',' to separate multiple algorithms "
        "(e.g. --autoq_format fp8,int4_awq,w4a8_awq). Note: fp8 and int8_sq "
        "cannot be used at the same time.")
    parser.add_argument(
        '--weight_compression',
        type=float,
        default=None,
        help="Target compressed size as a fraction of the original weight "
        "size when using mixed-precision (auto) quantization. The range is "
        "[0.0, 1.0]; if only autoq_format is given, it defaults to the "
        "lowest possible value.")
    args = parser.parse_args()

    # weight_compression check: a 4-bit format in the search space allows
    # weights to shrink to ~1/4 of their original size (0.25); with only
    # 8-bit formats the floor is ~1/2 (0.5).
    if args.autoq_format:
        lower_bound = 0.25 if '4' in args.autoq_format else 0.5
        if args.weight_compression is None or args.weight_compression < lower_bound:
            print(
                f"Invalid weight_compression value; it will be set to {lower_bound}"
            )
            args.weight_compression = lower_bound
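
    # Dispatch: HuggingFace checkpoints go through quantize_and_export,
    # NeMo checkpoints through quantize_nemo_and_export.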
    if args.model_dir is not None:
        quantize_and_export(
            model_dir=args.model_dir,
            device=args.device,
            calib_dataset=args.calib_dataset,
            dtype=args.dtype,
            # When weight_compression is set, the auto-quantization search
            # formats take precedence over --qformat.
            qformat=args.qformat
            if args.weight_compression is None else args.autoq_format,
            kv_cache_dtype=args.kv_cache_dtype,
            calib_size=args.calib_size,
            batch_size=args.batch_size,
            calib_max_seq_length=args.calib_max_seq_length,
            awq_block_size=args.awq_block_size,
            output_dir=args.output_dir,
            tp_size=args.tp_size,
            pp_size=args.pp_size,
            seed=args.seed,
            tokenizer_max_seq_length=args.tokenizer_max_seq_length,
            num_medusa_heads=args.num_medusa_heads,
            num_medusa_layers=args.num_medusa_layers,
            max_draft_len=args.max_draft_len,
            medusa_hidden_act=args.medusa_hidden_act,
            medusa_model_dir=args.medusa_model_dir,
            quant_medusa_head=args.quant_medusa_head,
            weight_compression=args.weight_compression)
    elif args.nemo_ckpt_path is not None:
        quantize_nemo_and_export(nemo_ckpt_path=args.nemo_ckpt_path,
                                 decoder_type=args.decoder_type,
                                 calib_dataset=args.calib_dataset,
                                 calib_tp_size=args.calib_tp_size,
                                 calib_pp_size=args.calib_pp_size,
                                 dtype=args.dtype,
                                 qformat=args.qformat,
                                 kv_cache_dtype=args.kv_cache_dtype,
                                 calib_size=args.calib_size,
                                 batch_size=args.batch_size,
                                 calib_max_seq_length=args.calib_max_seq_length,
                                 awq_block_size=args.awq_block_size,
                                 output_dir=args.output_dir,
                                 tp_size=args.tp_size,
                                 pp_size=args.pp_size,
                                 seed=args.seed)
    else:
        raise ValueError(
            "One source checkpoint (--model_dir or --nemo_ckpt_path) must be "
            "specified")
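
# Example invocations (paths are hypothetical; flags as defined above):
#   python quantize.py --model_dir ./llama-7b-hf --qformat fp8 \
#       --kv_cache_dtype fp8 --output_dir ./llama-7b-fp8
#   python quantize.py --model_dir ./llama-7b-hf --autoq_format fp8,int4_awq \
#       --weight_compression 0.4 --output_dir ./llama-7b-autoq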