mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
491 lines
20 KiB
Python
491 lines
20 KiB
Python
import argparse
|
||
import json
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
|
||
from tqdm import tqdm
|
||
from transformers import LlamaConfig
|
||
|
||
import tensorrt_llm
|
||
from tensorrt_llm.mapping import Mapping
|
||
from tensorrt_llm.models.eagle.config import EagleConfig
|
||
from tensorrt_llm.models.eagle.model import EagleForCausalLM
|
||
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
|
||
from tensorrt_llm.quantization import QuantAlgo
|
||
|
||
|
||
def parse_arguments(argv=None):
    """Build the EAGLE checkpoint-converter CLI and parse the command line.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the parsed options. Several model-shape fields
        (e.g. ``n_layer``, ``vocab_size``, ``rotary_base``, ``n_positions``)
        are only defaults; the script's ``__main__`` block overwrites them from
        the HF config found in ``--model_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--meta_ckpt_dir', type=str, default=None)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--dtype',
                        type=str,
                        default='auto',
                        choices=['auto', 'float16', 'bfloat16', 'float32'])
    parser.add_argument('--vocab_size', type=int, default=32000)
    parser.add_argument('--n_positions', type=int, default=2048)
    parser.add_argument('--n_layer', type=int, default=32)

    # NOTE: adjacent string literals are concatenated verbatim, so each
    # fragment that continues a sentence must end with a space.
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4', 'int4_gptq'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help=
        "The huggingface dataset name or the local directory of the dataset for calibration."
    )
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]")
    parser.add_argument(
        '--per_channel',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV cache.'
    )

    parser.add_argument(
        '--per_group',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group chooses at run time, and for each group, a custom scaling factor. '
        'The flag is built for GPTQ/AWQ quantization.')

    parser.add_argument('--load_by_shard',
                        action='store_true',
                        help='Load a pretrained model shard-by-shard.')
    parser.add_argument('--hidden_act', type=str, default='silu')

    parser.add_argument('--rotary_base', type=float, default=10000.0)
    parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None)

    parser.add_argument('--group_size',
                        type=int,
                        default=128,
                        help='Group size used in GPTQ/AWQ quantization.')

    parser.add_argument("--storage-type",
                        "-t",
                        type=str,
                        default="fp32",
                        choices=["fp32", "fp16"])
    parser.add_argument("--dataset-cache-dir",
                        type=str,
                        default=None,
                        help="cache dir to load the hugging face dataset")
    parser.add_argument("--load-model-on-cpu", action="store_true")
    parser.add_argument("--convert-model-on-cpu", action="store_true")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')

    # EAGLE-specific options.
    parser.add_argument('--eagle_model_dir', type=str, default=None)
    parser.add_argument('--max_draft_len', type=int, default=63)
    parser.add_argument(
        '--num_eagle_layers',
        type=int,
        default=4,
        help=
        'Maximum depth of the EAGLE choices tree, i.e. maximum number of accepted draft tokens.'
    )
    parser.add_argument(
        '--max_non_leaves_per_layer',
        type=int,
        default=10,
        help='Maximum number of non-leaf nodes in the EAGLE choice tree.')
    args = parser.parse_args(argv)
    return args


def _check_and_update(module, key_mapping):
    """Merge ``key_mapping`` into ``module.tllm_to_externel_key_dict``.

    If ``module`` has no such attribute yet, it is created from a copy of
    ``key_mapping`` (so the caller's dict is never aliased); otherwise the
    existing dict is updated in place.
    """
    if hasattr(module, 'tllm_to_externel_key_dict'):
        module.tllm_to_externel_key_dict.update(key_mapping)
    else:
        module.tllm_to_externel_key_dict = dict(key_mapping)


def _clone_tensors(tensors):
    """``preprocess`` hook that returns a clone of the loaded weight(s).

    Used for keys under a shared ``eagle_nets.{idx}`` prefix so each EAGLE net
    gets its own copy. Accepts ``None``, a single object with ``clone()``, or
    a list of them; ``None`` and lists containing ``None`` are returned as-is.
    """
    if tensors is None:
        return tensors
    if isinstance(tensors, list):
        if None in tensors:
            return tensors
        return [tensor.clone() for tensor in tensors]
    return tensors.clone()


def convert_and_save_hf(config, args):
    """Convert base + EAGLE weights into a TensorRT-LLM checkpoint, rank by rank.

    Args:
        config: dict accepted by ``EagleConfig.from_dict``.
        args: parsed CLI namespace; reads ``tp_size``, ``pp_size``,
            ``model_dir``, ``eagle_model_dir``, ``num_eagle_layers`` and
            ``output_dir``.

    Two source layouts are handled:
      * ``args.eagle_model_dir is None``: one ModelOpt checkpoint in
        ``args.model_dir``; EAGLE-net weights are mapped to ``eagle_module``.
      * otherwise: base model in ``args.model_dir`` and EAGLE draft weights in
        the separate ``args.eagle_model_dir`` checkpoint.

    For every rank the model is rebuilt, filled and saved to
    ``args.output_dir``; only rank 0 writes the config.
    """
    world_size = args.tp_size * args.pp_size
    # Resolved here rather than read from a module-level global, which only
    # exists when this file is executed as a script.
    eagle_model_dir = (args.model_dir
                       if args.eagle_model_dir is None else args.eagle_model_dir)
    tllm_config = EagleConfig.from_dict(config)

    for rank in range(world_size):
        tllm_config.mapping = Mapping(world_size=world_size,
                                      rank=rank,
                                      cp_size=1,
                                      tp_size=args.tp_size,
                                      pp_size=args.pp_size)

        model = EagleForCausalLM(tllm_config)

        shared_weight_prefixes = []
        tllm_weights = {}
        customized_dict = {"drafter": ""}

        if args.eagle_model_dir is None:
            # Single checkpoint for ModelOpt
            for idx, eagle_net in enumerate(model.eagle_nets):
                _check_and_update(eagle_net.drafter.fc, {"fc": "fc"})
                _check_and_update(eagle_net.drafter.vocab_embedding,
                                  {f"eagle_nets.{idx}": "model"})
                _check_and_update(eagle_net.lm_head, {f"eagle_nets.{idx}": ""})
                shared_weight_prefixes.append(f"eagle_nets.{idx}")
                customized_dict[f'eagle_nets.{idx}'] = 'eagle_module'
            shared = tuple(shared_weight_prefixes)

            loader = ModelWeightsLoader(eagle_model_dir, customized_dict)
            loader.update_key_mapping(model)
            for tllm_key, _ in tqdm(model.named_parameters()):
                if tllm_key.startswith(shared):
                    tllm_weights.update(
                        loader.load(tllm_key, preprocess=_clone_tensors))
                else:
                    tllm_weights.update(loader.load(tllm_key))
            loader.fill(tllm_weights)
        else:
            # Double checkpoint for HF
            for idx, eagle_net in enumerate(model.eagle_nets):
                _check_and_update(eagle_net.drafter.fc, {"fc": "fc"})
                _check_and_update(eagle_net.drafter.vocab_embedding,
                                  {f"eagle_nets.{idx}": ""})
                _check_and_update(eagle_net.lm_head, {f"eagle_nets.{idx}": ""})
                shared_weight_prefixes.append(f"eagle_nets.{idx}")
                customized_dict[f'eagle_nets.{idx}'] = ''
            shared = tuple(shared_weight_prefixes)

            # Load base model
            base_loader = ModelWeightsLoader(args.model_dir)
            base_loader.update_key_mapping(model)
            for tllm_key, _ in tqdm(model.transformer.named_parameters()):
                tllm_weights.update(base_loader.load("transformer." + tllm_key))
            tllm_weights.update(base_loader.load("lm_head.weight"))
            # Each EAGLE net's lm_head is loaded through the base loader
            # (presumably resolving to the base checkpoint's lm_head via the
            # key mapping above) and cloned per net.
            for idx in range(args.num_eagle_layers):
                tllm_weights.update(
                    base_loader.load(f"eagle_nets.{idx}.lm_head.weight",
                                     preprocess=_clone_tensors))

            # Load eagle model
            eagle_loader = ModelWeightsLoader(eagle_model_dir, customized_dict)
            eagle_loader.update_key_mapping(model)
            for tllm_key, _ in tqdm(model.eagle_nets.named_parameters()):
                if tllm_key.endswith("lm_head.weight"):
                    continue  # already loaded from the base checkpoint above
                # Names from ``model.eagle_nets`` are relative to it, which is
                # why ``eagle_nets.`` is prepended for the loader. Match the
                # shared prefixes against that full key; matching the relative
                # key never succeeds and the clone hook would be skipped.
                full_key = "eagle_nets." + tllm_key
                if full_key.startswith(shared):
                    tllm_weights.update(
                        eagle_loader.load(full_key, preprocess=_clone_tensors))
                else:
                    tllm_weights.update(eagle_loader.load(full_key))
            base_loader.fill(tllm_weights)

        model.save_checkpoint(args.output_dir, save_config=(rank == 0))


if __name__ == '__main__':
    # TODO(qijun): Currently, the convert script depends on a torch op:
    # torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix,
    # which is included in tensorrt_llm Python package. Otherwise, the convert
    # script does not need to import tensorrt_llm. Will remove it after reimplementing
    # the op with PyTorch.
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    world_size = args.tp_size * args.pp_size

    # Validate arguments before any filesystem side effects. Explicit raises
    # are used instead of ``assert`` so the checks survive ``python -O``.
    if args.pp_size != 1:
        raise ValueError("Pipeline parallelism is not supported in EAGLE yet.")
    if args.max_draft_len > 256:
        raise ValueError("args.max_draft_len > 256 is not supported")
    if args.model_dir is None:
        if args.meta_ckpt_dir is not None:
            raise NotImplementedError("meta ckpt is not supported yet")
        raise ValueError("--model_dir is required to convert an EAGLE model")

    tik = time.time()

    os.makedirs(args.output_dir, exist_ok=True)

    eagle_model_dir = (args.model_dir
                       if args.eagle_model_dir is None else args.eagle_model_dir)

    # Base-model hyper-parameters come from the HF config and override the
    # corresponding CLI defaults.
    hf_config = LlamaConfig.from_pretrained(args.model_dir)
    args.model_type = hf_config.model_type
    args.n_head = hf_config.num_attention_heads
    args.inter_size = hf_config.intermediate_size
    args.n_layer = hf_config.num_hidden_layers
    args.n_embd = hf_config.hidden_size
    args.n_kv_head = hf_config.num_key_value_heads
    args.rms_norm_eps = hf_config.rms_norm_eps
    args.vocab_size = hf_config.vocab_size
    args.rotary_scaling = hf_config.rope_scaling
    args.rotary_base = hf_config.rope_theta
    args.n_positions = hf_config.max_position_embeddings
    if args.dtype == 'auto':
        # Drops the 6-char 'torch.' prefix — assumes torch_dtype stringifies
        # as e.g. 'torch.float16'.
        args.dtype = str(hf_config.torch_dtype)[6:]
    head_dim = getattr(hf_config, 'head_dim', None)
    args.head_dim = (head_dim if head_dim is not None else args.n_embd //
                     args.n_head)
    head_size = getattr(hf_config, 'head_size', None)
    args.head_size = head_size if head_size is not None else args.head_dim

    # EAGLE draft-net hyper-parameters, read through one dict-shaped view.
    if args.eagle_model_dir is None:
        # Single ModelOpt checkpoint: the values live in the ``eagle`` entry
        # of the base config.
        eagle_cfg = dict(hf_config.eagle)
    else:
        # Separate HF checkpoint for the EAGLE draft model.
        eagle_cfg = LlamaConfig.from_pretrained(args.eagle_model_dir).to_dict()
    args.n_head_eagle = eagle_cfg['num_attention_heads']
    args.inter_size_eagle = eagle_cfg['intermediate_size']
    args.n_layer_eagle = eagle_cfg['num_hidden_layers']
    args.n_embd_eagle = eagle_cfg['hidden_size']
    args.n_kv_head_eagle = eagle_cfg['num_key_value_heads']
    args.rms_norm_eps_eagle = eagle_cfg['rms_norm_eps']
    args.n_positions_eagle = eagle_cfg['max_position_embeddings']
    head_dim_eagle = eagle_cfg.get('head_dim')
    args.head_dim_eagle = (head_dim_eagle if head_dim_eagle is not None else
                           args.n_embd_eagle // args.n_head_eagle)
    head_size_eagle = eagle_cfg.get('head_size')
    args.head_size_eagle = (head_size_eagle if head_size_eagle is not None
                            else args.head_dim_eagle)

    if args.rotary_scaling is not None:
        # assert args.use_gpt_attention_plugin, "RoPE scaling is only supported through GPT attention plugin."
        # Accept the scaling kind under either 'rope_type' or 'type'.
        # NOTE(review): only the kind is forwarded; any other rope_scaling
        # fields (e.g. a scaling factor) are dropped — confirm intended.
        rope_type = args.rotary_scaling.get(
            'rope_type', args.rotary_scaling.get('type'))
        args.rotary_scaling = {"type": rope_type}

    eagle_net_config = {
        'architecture': "LlamaForCausalLM",
        'dtype': args.dtype,
        'logits_dtype': 'float32',
        'num_hidden_layers': args.n_layer_eagle,
        'num_attention_heads': args.n_head_eagle,
        'hidden_size': args.n_embd_eagle,
        'intermediate_size': args.inter_size_eagle,
        'num_key_value_heads': args.n_kv_head_eagle,
        'vocab_size': args.vocab_size,
        'position_embedding_type': 'rope_gpt_neox',
        'max_position_embeddings': args.n_positions_eagle,
        'hidden_act': args.hidden_act,
        'rotary_base': args.rotary_base,
        'rotary_scaling': args.rotary_scaling,
        'norm_epsilon': args.rms_norm_eps_eagle,
        'quantization': {
            'quant_algo': None,
            'kv_cache_quant_algo': None,
        },
        'mapping': {
            'world_size': world_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
        },
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'head_dim': args.head_dim_eagle,
        'head_size': args.head_size_eagle
    }

    config = {
        'architecture': 'EagleForCausalLM',
        'dtype': args.dtype,
        'logits_dtype': 'float32',
        'num_hidden_layers': args.n_layer,
        'num_attention_heads': args.n_head,
        'hidden_size': args.n_embd,
        'intermediate_size': args.inter_size,
        'num_key_value_heads': args.n_kv_head,
        'vocab_size': args.vocab_size,
        'position_embedding_type': 'rope_gpt_neox',
        'max_position_embeddings': args.n_positions,
        'hidden_act': args.hidden_act,
        'rotary_base': args.rotary_base,
        'rotary_scaling': args.rotary_scaling,
        'norm_epsilon': args.rms_norm_eps,
        'quantization': {
            'quant_algo': None,
            'kv_cache_quant_algo': None,
        },
        'mapping': {
            'world_size': world_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
        },
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'max_draft_len': args.max_draft_len,
        'num_eagle_layers': args.num_eagle_layers,
        'max_non_leaves_per_layer': args.max_non_leaves_per_layer,
        'eagle_net_config': eagle_net_config,
        'head_dim': args.head_dim,
        'head_size': args.head_size
    }

    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            config['quantization']['quant_algo'] = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            config['quantization']['quant_algo'] = QuantAlgo.W4A16
    elif args.smoothquant is not None:
        # ``is not None``: α = 0 is a valid SmoothQuant setting.
        if args.per_channel:
            if args.per_token:
                config['quantization'][
                    'quant_algo'] = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                config['quantization'][
                    'quant_algo'] = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                config['quantization'][
                    'quant_algo'] = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                config['quantization'][
                    'quant_algo'] = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        config['quantization']['kv_cache_quant_algo'] = QuantAlgo.INT8

    # NOTE(review): unlike 'int8'/'int4', this applies even without
    # --use_weight_only, although the --weight_only_precision help says that
    # flag is required — confirm which behavior is intended.
    if args.weight_only_precision == 'int4_gptq':
        config['quantization'].update({
            "group_size": args.group_size,
            "has_zero_point": True,
            "pre_quant_scale": False,
            'quant_algo': QuantAlgo.W4A16_GPTQ
        })

    # Update quant config if hf_quant_config.json exists. Only the file read
    # is guarded; a missing file simply keeps the CLI-derived settings.
    quant_config_path = os.path.join(eagle_model_dir, 'hf_quant_config.json')
    try:
        with open(quant_config_path) as f:
            quant_config = json.load(f)
    except IOError:
        quant_config = None
    if quant_config is not None:
        quantization = quant_config['quantization']
        exclude_modules = quantization.get('exclude_modules') or []
        if "lm_head" in exclude_modules:
            # Exclude every EAGLE net's lm_head alongside the base lm_head.
            quantization['exclude_modules'] = exclude_modules + [
                f"eagle_nets.{i}.lm_head"
                for i in range(args.num_eagle_layers)
            ]
        config['quantization'].update(quantization)
        config['eagle_net_config']['quantization'].update(quantization)

    convert_and_save_hf(config, args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')