mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
365 lines
13 KiB
Python
365 lines
13 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
import numpy as np
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm._utils import release_gc
|
|
from tensorrt_llm.layers import MoeConfig
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.models import GrokForCausalLM
|
|
from tensorrt_llm.models.modeling_utils import QuantConfig
|
|
from tensorrt_llm.quantization import QuantAlgo
|
|
|
|
|
|
def parse_arguments():
    """Parse the command line of the Grok checkpoint converter.

    Returns:
        argparse.Namespace holding every option defined below. When
        ``--moe_num_experts`` is non-zero and ``--moe_top_k`` was left at 0,
        ``moe_top_k`` is set to 1 so the value matches the option's help text.
        ``moe_tp_size`` / ``moe_ep_size`` keep their -1 "unset" default here.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--weights_dir', type=str, default=None)

    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument('--vocab_size', type=int, default=32000)
    parser.add_argument('--n_positions', type=int, default=2048)
    parser.add_argument('--n_layer', type=int, default=32)
    parser.add_argument('--n_head', type=int, default=32)
    parser.add_argument('--n_kv_head', type=int, default=None)
    parser.add_argument('--n_embd', type=int, default=4096)
    parser.add_argument('--inter_size', type=int, default=11008)
    parser.add_argument('--rms_norm_eps', type=float, default=1e-06)

    # Adjacent string literals are concatenated verbatim, so each fragment
    # below ends with a space to keep the rendered help text readable.
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action="store_true",
        help=
        'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )

    parser.add_argument('--load_by_shard',
                        action='store_true',
                        help='Load a pretrained model shard-by-shard.')
    parser.add_argument('--hidden_act', type=str, default='silu')

    parser.add_argument('--rotary_base', type=float, default=10000.0)

    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')
    parser.add_argument(
        '--moe_num_experts',
        default=0,
        type=int,
        help='Specify the number of experts to use for MOE layers')
    parser.add_argument(
        '--moe_top_k',
        default=0,
        type=int,
        help=
        'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set'
    )
    parser.add_argument(
        '--moe_tp_size',
        type=int,
        default=-1,
        help=
        'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
    )
    parser.add_argument(
        '--moe_ep_size',
        type=int,
        default=-1,
        help=
        'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
    )
    parser.add_argument(
        '--moe_renorm_mode',
        default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
        type=int,
        help=
        'Controls renormalization after gate logits. Check layers/moe.py for accepted values',
    )
    parser.add_argument(
        '--save_config_only',
        action="store_true",
        default=False,
        help=
        'Only save the model config w/o read and converting weights, be careful, this is for debug only'
    )

    args = parser.parse_args()
    # changing the default to be consistent as the cli help said.
    if args.moe_num_experts and args.moe_top_k == 0:
        args.moe_top_k = 1
    return args
|
|
|
|
|
|
def args_to_quantization(args: argparse.Namespace) -> QuantConfig:
    """Translate the weight-only quantization CLI flags into a QuantConfig.

    A freshly constructed ``QuantConfig()`` is returned. Its ``quant_algo`` is
    set to ``QuantAlgo.W8A16`` only when ``--use_weight_only`` is given and
    ``--weight_only_precision`` is ``'int8'``. Otherwise the config is left as
    constructed.
    """
    result = QuantConfig()
    int8_weight_only = args.use_weight_only and args.weight_only_precision == 'int8'
    if int8_weight_only:
        result.quant_algo = QuantAlgo.W8A16
    return result
|
|
|
|
|
|
def args_to_build_options(args):
    """Pick the build-time options out of the parsed CLI namespace.

    Returns:
        dict mapping each option name to its value on ``args``, in the order
        use_parallel_embedding, embedding_sharding_dim,
        disable_weight_only_quant_plugin.
    """
    build_option_names = (
        'use_parallel_embedding',
        'embedding_sharding_dim',
        'disable_weight_only_quant_plugin',
    )
    return {name: getattr(args, name) for name in build_option_names}
|
|
|
|
|
|
def from_cli_args(args):
    """Synthesize a checkpoint config dict purely from command-line arguments.

    ``main`` uses this when no ``--model_dir`` is supplied, and dumps the result
    as ``config.json``. ``num_key_value_heads`` falls back to ``--n_head`` when
    ``--n_kv_head`` is not given. The build options from
    ``args_to_build_options`` are merged in last.
    """
    kv_heads = args.n_head if args.n_kv_head is None else args.n_kv_head

    parallel_layout = {
        'world_size': args.tp_size * args.pp_size,
        'tp_size': args.tp_size,
        'pp_size': args.pp_size,
        'moe_tp_size': args.moe_tp_size,
        'moe_ep_size': args.moe_ep_size,
    }

    # Key order is preserved so the dumped JSON keeps the same layout.
    # NOTE(review): the architecture string is "LlamaForCausalLM" even though
    # this is the Grok converter; confirm this is intended.
    config = {
        'architecture': "LlamaForCausalLM",
        'dtype': args.dtype,
        'logits_dtype': 'float32',
        'num_hidden_layers': args.n_layer,
        'num_attention_heads': args.n_head,
        'hidden_size': args.n_embd,
        'intermediate_size': args.inter_size,
        'num_key_value_heads': kv_heads,
        'vocab_size': args.vocab_size,
        'position_embedding_type': 'rope_gpt_neox',
        'max_position_embeddings': args.n_positions,
        'hidden_act': args.hidden_act,
        'rotary_base': args.rotary_base,
        'norm_epsilon': args.rms_norm_eps,
        'moe_num_experts': args.moe_num_experts,
        'moe_top_k': args.moe_top_k,
        'moe_normalization_mode': args.moe_renorm_mode,
        'mapping': parallel_layout,
        'quantization': args_to_quantization(args).asdict(),
    }
    for option, value in args_to_build_options(args).items():
        config[option] = value
    return config
|
|
|
|
|
|
def preload_model(model_dir, weights_dir=None):
    """Load the full Grok-1 parameter set once via the xAI reference code.

    ``model_dir`` is appended to ``sys.path``, and it stays there for the rest
    of the process, so that its ``model`` and ``runners`` modules can be
    imported. Weights are read from ``weights_dir`` when it is given and exists
    on disk. Otherwise they are read from ``<model_dir>/checkpoints``.

    Returns:
        Whatever ``ModelRunner.load_or_init`` yields for a dummy batch.
    """
    sys.path.append(model_dir)
    from model import LanguageModelConfig, TransformerConfig
    from runners import ModelRunner

    CKPT_PATH = (weights_dir if weights_dir and os.path.exists(weights_dir)
                 else os.path.join(model_dir, "checkpoints"))

    # Hard-coded Grok-1 architecture. Numerically, emb_size is 6144,
    # embedding_multiplier_scale == sqrt(6144), output_multiplier_scale ==
    # 1/sqrt(3) and attn_output_multiplier == 1/sqrt(128) (key_size).
    transformer_cfg = TransformerConfig(
        emb_size=48 * 128,
        widening_factor=8,
        key_size=128,
        num_q_heads=48,
        num_kv_heads=8,
        num_layers=64,
        attn_output_multiplier=0.08838834764831845,
        shard_activations=True,
        # MoE.
        num_experts=8,
        num_selected_experts=2,
        # Activation sharding.
        data_axis="data",
        model_axis="model",
    )
    grok_1_model = LanguageModelConfig(
        vocab_size=128 * 1024,
        pad_token=0,
        eos_token=2,
        sequence_len=8192,
        embedding_init_scale=1.0,
        output_multiplier_scale=0.5773502691896257,
        embedding_multiplier_scale=78.38367176906169,
        model=transformer_cfg,
    )

    runner = ModelRunner(
        model=grok_1_model,
        bs_per_device=0.125,
        checkpoint_path=CKPT_PATH,
    )
    # Two separate zero arrays, matching the original construction.
    dummy_batch = {
        'inputs': np.zeros((1, 256), dtype=np.int32),
        'targets': np.zeros((1, 256), dtype=np.int32),
    }
    # Order matters: the flag is set before initialize(), which runs before
    # load_or_init(). The tuple arguments are presumably mesh shapes;
    # TODO confirm against the xAI runners module.
    runner.transform_forward = True
    runner.initialize(dummy_batch, (1, 8), (1, 1))

    return runner.load_or_init(dummy_batch)
|
|
|
|
|
|
def _resolve_moe_parallelism(args):
    """Fill unset (-1) MoE TP/EP sizes from --tp_size in place and validate them.

    Raises:
        ValueError: if a size is neither -1 nor positive, or if
            moe_tp_size * moe_ep_size != tp_size after resolution.
    """
    # Reject 0 / negative values up front: they would otherwise surface as a
    # ZeroDivisionError in the floor divisions below.
    for name in ('moe_tp_size', 'moe_ep_size'):
        value = getattr(args, name)
        if value != -1 and value <= 0:
            raise ValueError(
                f"--{name} must be a positive integer or -1, got {value}")

    if args.moe_tp_size == -1 and args.moe_ep_size == -1:
        # moe default to tp-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif args.moe_tp_size == -1:
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif args.moe_ep_size == -1:
        args.moe_ep_size = args.tp_size // args.moe_tp_size

    # An explicit raise (not assert) so the check survives `python -O`.
    if args.moe_tp_size * args.moe_ep_size != args.tp_size:
        raise ValueError("moe_tp_size * moe_ep_size must equal to tp_size")


def convert_and_save_xai(args):
    """Convert an xAI Grok checkpoint into per-rank TensorRT-LLM checkpoints.

    Resolves the MoE TP/EP split (mutating ``args``). Unless
    ``--load_by_shard`` is set, it preloads the full model once so each rank can
    slice its weights from memory. It then builds and saves one
    ``GrokForCausalLM`` per rank via ``execute``. Only rank 0 writes the config.

    Raises:
        ValueError: on an inconsistent MoE parallelism configuration.
    """
    model_dir = args.model_dir
    load_by_shard = args.load_by_shard
    world_size = args.tp_size * args.pp_size
    _resolve_moe_parallelism(args)

    # Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
    # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
    # before the refactor is done.
    override_fields = {}
    quantization = args_to_quantization(args)
    override_fields.update(args_to_build_options(args))

    # When not loading by shard, preload one complete model and then slice per rank weights from this
    # this saves the disk reloading time
    xai_model = preload_model(
        model_dir, args.weights_dir) if not load_by_shard else None

    def convert_and_save_rank(args, rank):
        mapping = Mapping(world_size=world_size,
                          rank=rank,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size,
                          moe_tp_size=args.moe_tp_size,
                          moe_ep_size=args.moe_ep_size)
        grok = GrokForCausalLM.from_hugging_face(
            model_dir,
            args.dtype,
            mapping=mapping,
            quantization=quantization,
            load_by_shard=load_by_shard,
            override_fields=override_fields,
            preloaded_model=xai_model,
        )
        grok.save_checkpoint(args.output_dir, save_config=(rank == 0))
        del grok

    execute(args.workers, [convert_and_save_rank] * world_size, args)
    release_gc()
|
|
|
|
|
|
def execute(workers, func, args):
    """Run every per-rank callable, either serially or on a thread pool.

    Args:
        workers: Number of worker threads. With 1, the callables run serially
            in the calling thread.
        func: Sequence of callables. The callable at index ``rank`` is invoked
            as ``f(args, rank)``.
        args: Passed unchanged as the first argument of every call.

    Raises:
        RuntimeError: when ``workers != 1`` and at least one callable raised.
            Every failure's traceback is printed first, and the first collected
            exception is chained as ``__cause__``. With ``workers == 1`` the
            first exception propagates directly.
    """
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
        return

    exceptions = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(f, args, rank) for rank, f in enumerate(func)]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                exceptions.append(e)

    # An `assert` here would be stripped under `python -O`, letting failed
    # ranks go unnoticed while the script reports success.
    if exceptions:
        raise RuntimeError(
            "Checkpoint conversion failed, please check error log."
        ) from exceptions[0]
|
|
|
|
|
|
def main():
    """Script entry point for the Grok checkpoint converter.

    Prints the TensorRT-LLM version and parses the CLI, then makes sure
    ``--output_dir`` exists. Without ``--model_dir`` it writes a
    ``config.json`` synthesized from the CLI arguments. Otherwise it converts
    the xAI checkpoint. Finally it prints the elapsed wall-clock time.
    """
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    tik = time.time()

    # exist_ok replaces the racy "check, then create" sequence.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.model_dir is None:
        # No source model: generate a config.json from CLI arguments only.
        config = from_cli_args(args)
        with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)
    else:
        convert_and_save_xai(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()
|