TensorRT-LLM/examples/gemma/convert_checkpoint.py

#!/usr/bin/env python3
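"""Convert Gemma checkpoints (JAX, Keras, PyTorch, or Hugging Face) into the
TensorRT-LLM checkpoint format, optionally applying weight-only, SmoothQuant,
FP8, or INT8-KV-cache quantization along the way.

Illustrative invocation (paths and flag values below are placeholders, not
defaults; all flags are defined in parse_arguments):

    python3 convert_checkpoint.py \
        --ckpt-type hf \
        --model-dir ./gemma-2b-hf \
        --output-model-dir ./gemma-2b-trtllm \
        --dtype bfloat16 \
        --world-size 1
"""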
import argparse
import time
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Type

import tensorrt_llm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.gemma.config import GEMMA_ARCHITECTURE, GemmaConfig
from tensorrt_llm.models.gemma.convert import (HfParser, JAXParser, KerasParser,
                                               Parsers, QuantizeModifiers,
                                               TorchParser, load_gemma_weights,
                                               non_modelopt_quantize_if_needed)
from tensorrt_llm.models.gemma.model import GemmaForCausalLM
from tensorrt_llm.models.modeling_utils import (QuantConfig, save_checkpoint,
                                                save_config)
from tensorrt_llm.quantization import QuantAlgo


class CheckpointType(str, Enum):
    jax = "jax"
    keras = "keras"
    torch = "torch"
    hf = "hf"


def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-type",
                        type=CheckpointType,
                        choices=list(CheckpointType))
    parser.add_argument("--model-dir", type=Path, required=True)
    parser.add_argument("--output-model-dir", type=Path, required=True)
    parser.add_argument("--world-size",
                        type=int,
                        default=1,
                        help="World size; only tensor parallelism is supported for now.")
    parser.add_argument(
        "--use-weight-only-with-precision",
        choices=["int8", "int4", "w4a8_awq", "w4a16_awq"],
        help=
        "Quantize weights for the various GEMMs to INT4/INT8. Define the precision for the weights.",
    )
    parser.add_argument(
        "--use-int8-weight-only-embedding",
        action="store_true",
        help=
        "Use weight only on embedding table and lm_head. (Only supported on Hopper GPU)",
    )
    parser.add_argument("--dtype",
                        type=str,
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--enable_fp8",
        action="store_true",
        help="Use FP8 Linear layer for Attention QKV/Dense and MLP.")
    parser.add_argument(
        "--fp8_kv_cache",
        action="store_true",
        help=
        "By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV",
    )
    parser.add_argument(
        "--quant_ckpt_path",
        default=None,
        help=
        "Path of a directory to quantized model checkpoints in .safetensors format or "
        "path of a quantized model checkpoint in .npz format")
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help=
        "The huggingface dataset name or the local directory of the dataset for calibration."
    )
    parser.add_argument('--use_smooth_quant',
                        action="store_true",
                        help="Use smooth quant.")
    parser.add_argument(
        "--int8_kv_cache",
        "--calibrate_kv_cache",
        "-kv",
        action="store_true",
        help=
        "Generate scaling factors for KV cache. Used for storing KV cache in int8."
    )
    parser.add_argument(
        '--per_channel',
        action="store_true",
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        "--smoothquant",
        "--use_smooth_quant_plugin",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]")
    parser.add_argument(
        '--tokenizer_dir',
        default=None,
        help='tokenizer path; defaults to jax_model_dir if left unspecified')
    parser.add_argument("--load_model_on_cpu", action="store_true")

    args = parser.parse_args()
    return args


CKPT_PARSER: Dict[CheckpointType, Type[Parsers]] = {
    CheckpointType.jax: JAXParser,
    CheckpointType.keras: KerasParser,
    CheckpointType.torch: TorchParser,
    CheckpointType.hf: HfParser
}


def compute_quant_algo(args: argparse.Namespace) -> Optional[QuantAlgo]:
    if args.use_weight_only_with_precision:
        return {
            "int8": QuantAlgo.W8A16,
            "int4": QuantAlgo.W4A16,
            "w4a8_awq": QuantAlgo.W4A8_AWQ,
            "w4a16_awq": QuantAlgo.W4A16_AWQ,
        }[args.use_weight_only_with_precision]
    elif args.enable_fp8:
        return QuantAlgo.FP8
    if args.use_smooth_quant:
        return QuantAlgo.W8A8_SQ_PER_CHANNEL
    elif args.smoothquant is not None:
        if args.per_token and args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
        elif not args.per_token and not args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
        elif not args.per_token and args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        elif args.per_token and not args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
    return None
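
# How --smoothquant interacts with --per_channel / --per_token in
# compute_quant_algo above (a restatement of the branches, not new behavior):
#
#   per_channel  per_token  resulting QuantAlgo
#   -----------  ---------  --------------------------------------
#   no           no         W8A8_SQ_PER_TENSOR_PLUGIN
#   yes          no         W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
#   no           yes        W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
#   yes          yes        W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN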


def create_quant_config(args: argparse.Namespace) -> QuantConfig:
    quant_algo = compute_quant_algo(args)
    GemmaForCausalLM.assert_valid_quant_algo(quant_algo)
    quant_config = QuantConfig(quant_algo=quant_algo,
                               smoothquant_val=args.smoothquant)
    if args.fp8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.FP8
    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8
    if args.use_weight_only_with_precision:
        use_awq = args.use_weight_only_with_precision.endswith("awq")
        use_int4 = args.use_weight_only_with_precision.endswith("int4")
        if use_awq:
            quant_config.group_size = 128
        if use_awq or use_int4 or not args.use_int8_weight_only_embedding:
            quant_config.has_zero_point = False
            quant_config.pre_quant_scale = True
        else:
            quant_config.exclude_modules = ['router']
    return quant_config
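
# Worked example for the function above (the flag combination is only an
# illustration): with --use-weight-only-with-precision w4a16_awq and no
# KV-cache flags, the resulting QuantConfig carries quant_algo=W4A16_AWQ,
# group_size=128, has_zero_point=False and pre_quant_scale=True.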


def main() -> None:
    args = parse_arguments()
    tik = time.time()

    quant_config = create_quant_config(args)
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()

    # Pipeline parallelism is not supported for Gemma yet, hence pp_size=1.
    mapping = Mapping(
        world_size=args.world_size,
        tp_size=args.world_size,
        pp_size=1,
    )

    if isinstance(ckpt_parser, HfParser):
        trt_llm_config = GemmaConfig.from_hugging_face(
            args.model_dir,
            args.dtype,
            mapping=mapping,
            quant_config=quant_config,
        )
    else:
        print(f"Loading source parameters from {args.model_dir.absolute()}")
        ckpt_params = ckpt_parser.load_parameters(args.model_dir)
        input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
        num_embed, _ = input_embedding_weights.shape
        ckpt_params_dtype = str(input_embedding_weights.dtype).split(".")[
            -1]  # np.bfloat16 -> bfloat16
        ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params,
                                             num_embed)
        # 2B TransformerConfig(num_layers=18, num_embed=256128, embed_dim=2048, hidden_dim=16384, num_heads=8, head_dim=256, num_kv_heads=1)
        # 7B TransformerConfig(...)
        del ckpt_params
        print(f"Source configuration determined from parameters: {ckpt_config}")

        trt_llm_config = tensorrt_llm.models.GemmaConfig(
            architecture=GEMMA_ARCHITECTURE,
            dtype=args.dtype or ckpt_params_dtype,
            logits_dtype="float32",
            vocab_size=ckpt_config.num_embed,
            max_position_embeddings=8192,
            hidden_size=ckpt_config.embed_dim,
            num_hidden_layers=ckpt_config.num_layers,
            num_attention_heads=ckpt_config.num_heads,
            num_key_value_heads=ckpt_config.num_kv_heads,
            head_size=ckpt_config.head_dim,
            hidden_act="gelu",
            intermediate_size=ckpt_config.hidden_dim,
            norm_epsilon=1e-6,  # hard-coded in RMSNorm from gemma/layers.py
            position_embedding_type="rope_gpt_neox",
            mapping=mapping,
            gpus_per_node=8,
            quantization=quant_config,
            use_parallel_embedding=mapping.tp_size > 1,
            share_embedding_table=True,
        )

        if hasattr(ckpt_config,
                   "model_type") and ckpt_config.model_type == "gemma2":
            trt_llm_config.inter_layernorms = True
            trt_llm_config.final_logit_softcapping = ckpt_config.final_logit_softcapping
            trt_llm_config.attn_logit_softcapping = ckpt_config.attn_logit_softcapping
            trt_llm_config.query_pre_attn_scalar = ckpt_config.query_pre_attn_scalar

    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    save_config(trt_llm_config, output_dir=args.output_model_dir, log=True)
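
    # One TensorRT-LLM checkpoint shard is written per tensor-parallel rank:
    # for each rank the weights are loaded, quantized if requested, and saved.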
    for config in trt_llm_config.for_each_rank():
        hf_weights = load_gemma_weights(
            parameters_or_model_dir=args.model_dir,
            trt_llm_config=config,
            ckpt_parser=ckpt_parser,
            load_model_on_cpu=args.load_model_on_cpu)
        ranked_weights = non_modelopt_quantize_if_needed(
            hf_weights,
            model_dir=args.model_dir,
            quantize_modifiers=QuantizeModifiers.from_args(args),
            trt_llm_config=config)
        save_checkpoint(output_dir=args.output_model_dir,
                        weights=ranked_weights,
                        rank=config.mapping.rank)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()