# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00).
# Source listing: Python, 268 lines, 9.5 KiB.
import argparse
|
||
import os
|
||
import time
|
||
import traceback
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
import tensorrt_llm
|
||
from tensorrt_llm._utils import release_gc
|
||
from tensorrt_llm.logger import logger
|
||
from tensorrt_llm.mapping import Mapping
|
||
from tensorrt_llm.models import ChatGLMForCausalLM
|
||
from tensorrt_llm.models.chatglm.config import GLM_VERSIONS
|
||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||
from tensorrt_llm.quantization import QuantAlgo
|
||
|
||
|
||
def parse_arguments():
    """Build the converter's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. Defaults and allowed choices
        are declared on each ``add_argument`` call below.

    Exits (via ``parser.error``) when ``--smoothquant`` is outside [0, 1],
    the range its help text requires.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument(
        '--chatglm_version',
        default=None,
        choices=[None] + GLM_VERSIONS,
        help=
        "By default the script will try to infer the chatglm_version from model_dir. "
        "Or users may overwrite chatglm_version by explicitly passing the version."
    )
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--cp_size',
                        type=int,
                        default=1,
                        help='N-way context parallelism size')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help=
        "The data type for the model weights and activations if not quantized. "
        "If 'auto', the data type is automatically inferred from the source model; "
        "however, if the source dtype is float32, it is converted to float16.")
    parser.add_argument('--logits_dtype',
                        type=str,
                        default='float32',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    # The adjacent string literals below are concatenated by Python, so each
    # sentence must carry its own trailing separator.
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )

    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='cnn_dailymail',
        help=
        "The huggingface dataset name or the local directory of the dataset for calibration."
    )
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]")
    parser.add_argument(
        '--per_channel',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
    )

    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument(
        '--device',
        help=
        "The device to run calibration; effective for HuggingFace model only.",
        default='cuda',
        choices=['cuda', 'cpu'])
    parser.add_argument("--load_model_on_cpu", action="store_true")
    args = parser.parse_args()

    # The --smoothquant help promises alpha in [0, 1]; reject anything else
    # (including NaN, for which both comparisons are False) up front.
    if args.smoothquant is not None and not 0.0 <= args.smoothquant <= 1.0:
        parser.error(
            f"--smoothquant must be in [0, 1], got {args.smoothquant}")

    return args


def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''Translate the parsed command-line options into a ``QuantConfig``.

    Precedence: weight-only quantization (``--use_weight_only``) wins over
    SmoothQuant when both are requested. The INT8 KV-cache setting is applied
    independently of either.

    Args:
        args: namespace produced by ``parse_arguments``. Reads
            ``use_weight_only``, ``weight_only_precision``, ``smoothquant``,
            ``per_channel``, ``per_token`` and ``int8_kv_cache``.

    Returns:
        QuantConfig: a fresh config whose ``quant_algo``, ``smoothquant_val``
        and ``kv_cache_quant_algo`` are set according to ``args``.
    '''
    quant_config = QuantConfig()

    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    # Compare against None instead of relying on truthiness. alpha == 0.0 lies
    # inside the documented [0, 1] range and must still pick a SmoothQuant
    # algorithm. convert_and_save_hf also tests `smoothquant is not None` to
    # enter the calibration path, so the two checks must agree.
    elif args.smoothquant is not None:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    return quant_config


def args_to_build_options(args):
    """Pick out the CLI options that override fields of the converted model config.

    Args:
        args: namespace produced by ``parse_arguments``.

    Returns:
        dict: ``use_parallel_embedding``, ``embedding_sharding_dim`` and
        ``logits_dtype`` copied from ``args`` (in that order), ready to be
        splatted as keyword arguments.
    """
    forwarded = ('use_parallel_embedding', 'embedding_sharding_dim',
                 'logits_dtype')
    return {option: getattr(args, option) for option in forwarded}


def execute(workers, func):
    """Invoke ``func[rank](rank)`` for every rank, sequentially or on a thread pool.

    Args:
        workers: number of worker threads. ``1`` runs the callables one after
            another in the current thread, in rank order; any other value is
            passed to ``ThreadPoolExecutor(max_workers=...)``.
        func: sequence of callables. The callable at index ``rank`` is called
            with ``rank`` as its only argument.

    Raises:
        Exception: with ``workers == 1``, the first failure propagates as-is.
        RuntimeError: with a thread pool, after every task has finished, if
            any task raised. Each failure's traceback is printed first, and
            the first collected failure is chained as ``__cause__``.
    """
    if workers == 1:
        for rank, f in enumerate(func):
            f(rank)
        return

    exceptions = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(f, rank) for rank, f in enumerate(func)]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                exceptions.append(e)

    # Explicit raise rather than `assert`: assertions are stripped under
    # `python -O`, which would let a failed conversion pass silently.
    if exceptions:
        raise RuntimeError(
            "Checkpoint conversion failed, please check error log."
        ) from exceptions[0]


def convert_and_save_hf(args):
    """Convert a Hugging Face ChatGLM checkpoint into a TensorRT-LLM checkpoint.

    Two paths, selected by the requested quantization:

    * ``--smoothquant`` given and/or ``--int8_kv_cache``: the whole job is
      delegated to ``ChatGLMForCausalLM.quantize``, which receives the
      calibration dataset and device from ``args``.
    * otherwise: one ``ChatGLMForCausalLM.from_hugging_face`` shard is built
      and saved per tp*pp rank, spread over ``args.workers`` threads.

    Args:
        args: namespace produced by ``parse_arguments``. Output is written to
            ``args.output_dir``.
    """
    world_size = args.tp_size * args.pp_size
    quant_config = args_to_quant_config(args)
    override_fields = args_to_build_options(args)

    if args.smoothquant is not None or args.int8_kv_cache:
        # Include the context-parallel dimension in the total world size. The
        # per-rank path below ends up with world_size * cp_size in its saved
        # mapping, and passing only tp*pp here would disagree with cp_size
        # whenever cp_size > 1.
        # NOTE(review): Mapping is expected to require
        # world_size == tp_size * pp_size * cp_size — confirm against
        # tensorrt_llm.mapping.
        mapping = Mapping(world_size=world_size * args.cp_size,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size,
                          cp_size=args.cp_size)
        ChatGLMForCausalLM.quantize(args.model_dir,
                                    args.output_dir,
                                    dtype=args.dtype,
                                    mapping=mapping,
                                    quant_config=quant_config,
                                    device=args.device,
                                    calib_dataset=args.calib_dataset,
                                    **override_fields)
    else:

        def convert_and_save_rank(rank):
            """Build and save the shard for one tp*pp ``rank``."""
            mapping = Mapping(world_size=world_size,
                              rank=rank,
                              tp_size=args.tp_size,
                              pp_size=args.pp_size)
            glm = ChatGLMForCausalLM.from_hugging_face(
                args.model_dir,
                args.dtype,
                mapping=mapping,
                quant_config=quant_config,
                chatglm_version=args.chatglm_version,
                load_model_on_cpu=args.load_model_on_cpu,
                **override_fields)
            # The shard is built with a tp*pp mapping. Context parallelism is
            # recorded on the saved config afterwards.
            glm.config.mapping.cp_size = args.cp_size
            glm.config.mapping.world_size *= args.cp_size
            # Only rank 0 asks save_checkpoint to write the config.
            glm.save_checkpoint(args.output_dir, save_config=(rank == 0))
            # Drop the model reference before the next shard is built.
            del glm

        execute(args.workers, [convert_and_save_rank] * world_size)
        release_gc()


def main():
    """Script entry point: parse options, convert the checkpoint, print elapsed time.

    Raises:
        ValueError: if ``--pp_size`` is not 1 (pipeline parallelism is not
            supported by this converter).
    """
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    logger.set_level(args.log_level)

    # Explicit check instead of `assert`, so the guard still applies when
    # Python runs with -O (which strips assert statements).
    if args.pp_size != 1:
        raise ValueError("Pipeline parallelism is not supported.")

    tik = time.time()

    # exist_ok=True replaces the racy "check exists, then create" sequence.
    os.makedirs(args.output_dir, exist_ok=True)

    convert_and_save_hf(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


# Run the conversion only when this file is executed as a script; importing
# the module defines the helpers without side effects.
if __name__ == '__main__':
    main()