# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert xAI Grok weights into a TensorRT-LLM checkpoint.

With ``--model_dir`` the weights are loaded and saved per rank through
``GrokForCausalLM``; without it, only a ``config.json`` built from the CLI
arguments is written to ``--output_dir``.
"""
import argparse
import json
import os
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np

import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.layers import MoeConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import GrokForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
    """Parse command-line options for the checkpoint conversion.

    Returns:
        argparse.Namespace with all options. ``moe_top_k`` is bumped to 1
        when experts are requested but no top-k was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--weights_dir', type=str, default=None)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument('--vocab_size', type=int, default=32000)
    parser.add_argument('--n_positions', type=int, default=2048)
    parser.add_argument('--n_layer', type=int, default=32)
    parser.add_argument('--n_head', type=int, default=32)
    parser.add_argument('--n_kv_head', type=int, default=None)
    parser.add_argument('--n_embd', type=int, default=4096)
    parser.add_argument('--inter_size', type=int, default=11008)
    parser.add_argument('--rms_norm_eps', type=float, default=1e-06)

    # NOTE: adjacent string literals are concatenated verbatim, so each
    # fragment below ends with an explicit space.
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action="store_true",
        help=
        'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument('--load_by_shard',
                        action='store_true',
                        help='Load a pretrained model shard-by-shard.')
    parser.add_argument('--hidden_act', type=str, default='silu')
    parser.add_argument('--rotary_base', type=float, default=10000.0)
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')
    parser.add_argument(
        '--moe_num_experts',
        default=0,
        type=int,
        help='Specify the number of experts to use for MOE layers')
    parser.add_argument(
        '--moe_top_k',
        default=0,
        type=int,
        help=
        'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set'
    )
    parser.add_argument(
        '--moe_tp_size',
        type=int,
        default=-1,
        help=
        'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
    )
    parser.add_argument(
        '--moe_ep_size',
        type=int,
        default=-1,
        help=
        'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
    )
    parser.add_argument(
        '--moe_renorm_mode',
        default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
        type=int,
        help=
        'Controls renormalization after gate logits. Check layers/moe.py for accepted values',
    )
    # NOTE(review): --save_config_only is parsed but not consumed anywhere in
    # this script; confirm whether it should short-circuit the conversion.
    parser.add_argument(
        '--save_config_only',
        action="store_true",
        default=False,
        help=
        'Only save the model config w/o read and converting weights, be careful, this is for debug only'
    )
    args = parser.parse_args()
    # changing the default to be consistent as the cli help said.
    if args.moe_num_experts and args.moe_top_k == 0:
        args.moe_top_k = 1
    return args


def _resolve_moe_parallelism(args: argparse.Namespace) -> None:
    """Fill in unset (-1) ``moe_tp_size`` / ``moe_ep_size`` on ``args`` in place.

    With neither given, MoE defaults to tensor-parallel only
    (``moe_tp_size = tp_size``, ``moe_ep_size = 1``); with one given, the other
    is derived from ``tp_size``. Safe to call more than once: after the first
    call neither value is -1, so only the validation runs.

    Raises:
        ValueError: if a size is neither -1 nor positive, or if
            ``moe_tp_size * moe_ep_size != tp_size``.
    """
    for name in ('moe_tp_size', 'moe_ep_size'):
        value = getattr(args, name)
        # Guards the integer divisions below against zero/negative divisors.
        if value != -1 and value <= 0:
            raise ValueError(
                f"--{name} must be a positive integer or -1, got {value}")

    if args.moe_tp_size == -1 and args.moe_ep_size == -1:
        # moe default to tp-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif args.moe_tp_size == -1:
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif args.moe_ep_size == -1:
        args.moe_ep_size = args.tp_size // args.moe_tp_size

    # Explicit raise instead of assert so the check survives `python -O`.
    if args.moe_tp_size * args.moe_ep_size != args.tp_size:
        raise ValueError("moe_tp_size * moe_ep_size must equal to tp_size")


def args_to_quantization(args: argparse.Namespace) -> QuantConfig:
    '''Return a QuantConfig describing the quantization requested on the CLI.

    Only weight-only INT8 (``QuantAlgo.W8A16``) is mapped; otherwise the
    default-constructed QuantConfig is returned unchanged.
    '''
    quant_config = QuantConfig()
    if args.use_weight_only and args.weight_only_precision == 'int8':
        quant_config.quant_algo = QuantAlgo.W8A16
    return quant_config


def args_to_build_options(args):
    """Collect the CLI options that are passed through as config overrides."""
    return {
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'disable_weight_only_quant_plugin':
        args.disable_weight_only_quant_plugin
    }


def from_cli_args(args):
    """Build the ``config.json`` dict used when no ``--model_dir`` is given.

    Resolves unset MoE TP/EP sizes on ``args`` (in place) first, so the
    written mapping never contains the -1 placeholders.

    Returns:
        A JSON-serialisable dict of model, mapping, quantization and build
        options.
    """
    _resolve_moe_parallelism(args)
    n_kv_head = args.n_kv_head if args.n_kv_head is not None else args.n_head
    config = {
        # NOTE(review): architecture is "LlamaForCausalLM" although this
        # script converts to GrokForCausalLM; confirm this is intended.
        'architecture': "LlamaForCausalLM",
        'dtype': args.dtype,
        'logits_dtype': 'float32',
        'num_hidden_layers': args.n_layer,
        'num_attention_heads': args.n_head,
        'hidden_size': args.n_embd,
        'intermediate_size': args.inter_size,
        'num_key_value_heads': n_kv_head,
        'vocab_size': args.vocab_size,
        'position_embedding_type': 'rope_gpt_neox',
        'max_position_embeddings': args.n_positions,
        'hidden_act': args.hidden_act,
        'rotary_base': args.rotary_base,
        'norm_epsilon': args.rms_norm_eps,
        'moe_num_experts': args.moe_num_experts,
        'moe_top_k': args.moe_top_k,
        'moe_normalization_mode': args.moe_renorm_mode,
        'mapping': {
            'world_size': args.tp_size * args.pp_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
            'moe_tp_size': args.moe_tp_size,
            'moe_ep_size': args.moe_ep_size,
        },
        'quantization': args_to_quantization(args).asdict()
    }
    config.update(args_to_build_options(args))
    return config


def preload_model(model_dir, weights_dir=None):
    """Load the full model parameters once via the runner shipped in ``model_dir``.

    ``model_dir`` is appended to ``sys.path`` (permanently, for the rest of the
    process) so its ``model`` and ``runners`` modules can be imported;
    presumably this is the xAI grok-1 reference repo — confirm.

    Args:
        model_dir: directory providing ``model.py`` / ``runners.py`` and, by
            default, a ``checkpoints`` sub-directory.
        weights_dir: optional checkpoint directory; used instead of
            ``<model_dir>/checkpoints`` when it exists.

    Returns:
        Whatever ``ModelRunner.load_or_init`` returns for a dummy batch.
    """
    sys.path.append(model_dir)
    from model import LanguageModelConfig, TransformerConfig
    from runners import ModelRunner

    if weights_dir and os.path.exists(weights_dir):
        ckpt_path = weights_dir
    else:
        ckpt_path = os.path.join(model_dir, "checkpoints")

    # Hard-coded Grok-1 model configuration expected by the reference runner.
    grok_1_model = LanguageModelConfig(
        vocab_size=128 * 1024,
        pad_token=0,
        eos_token=2,
        sequence_len=8192,
        embedding_init_scale=1.0,
        output_multiplier_scale=0.5773502691896257,
        embedding_multiplier_scale=78.38367176906169,
        model=TransformerConfig(
            emb_size=48 * 128,
            widening_factor=8,
            key_size=128,
            num_q_heads=48,
            num_kv_heads=8,
            num_layers=64,
            attn_output_multiplier=0.08838834764831845,
            shard_activations=True,
            # MoE.
            num_experts=8,
            num_selected_experts=2,
            # Activation sharding.
            data_axis="data",
            model_axis="model",
        ),
    )
    runner = ModelRunner(
        model=grok_1_model,
        bs_per_device=0.125,
        checkpoint_path=ckpt_path,
    )
    # Dummy (1, 256) int32 batch: only used so the runner can initialise.
    dummy_data = dict(
        inputs=np.zeros((1, 256), dtype=np.int32),
        targets=np.zeros((1, 256), dtype=np.int32),
    )
    runner.transform_forward = True
    runner.initialize(dummy_data, (1, 8), (1, 1))
    params = runner.load_or_init(dummy_data)
    return params


def convert_and_save_xai(args):
    """Convert the model in ``args.model_dir`` and save one checkpoint per rank.

    Resolves MoE parallelism on ``args`` in place, then builds a ``Mapping``
    per rank and calls ``GrokForCausalLM.from_hugging_face`` followed by
    ``save_checkpoint`` into ``args.output_dir`` (config saved by rank 0 only).
    """
    model_dir = args.model_dir
    load_by_shard = args.load_by_shard
    world_size = args.tp_size * args.pp_size

    _resolve_moe_parallelism(args)

    # Need to convert the cli args to the key-value pairs and override them in the generate config dict.
    # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
    # before the refactor is done.
    override_fields = args_to_build_options(args)
    quantization = args_to_quantization(args)

    # When not loading by shard, preload one complete model and then slice per rank weights from this
    # this saves the disk reloading time
    xai_model = preload_model(
        model_dir, args.weights_dir) if not args.load_by_shard else None

    def convert_and_save_rank(args, rank):
        mapping = Mapping(world_size=world_size,
                          rank=rank,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size,
                          moe_tp_size=args.moe_tp_size,
                          moe_ep_size=args.moe_ep_size)

        grok = GrokForCausalLM.from_hugging_face(
            model_dir,
            args.dtype,
            mapping=mapping,
            quantization=quantization,
            load_by_shard=load_by_shard,
            override_fields=override_fields,
            preloaded_model=xai_model,
        )
        grok.save_checkpoint(args.output_dir, save_config=(rank == 0))
        del grok

    execute(args.workers, [convert_and_save_rank] * world_size, args)
    release_gc()


def execute(workers, func, args):
    """Call ``func[rank](args, rank)`` for every rank.

    Runs serially when ``workers == 1``, otherwise on a thread pool of
    ``workers`` threads; in that case every rank is attempted and each
    failure's traceback is printed before reporting.

    Raises:
        RuntimeError: (parallel mode) if any rank raised; chained to the
            first collected exception.
    """
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
        return

    exceptions = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(f, args, rank) for rank, f in enumerate(func)]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                exceptions.append(e)
    # Explicit raise instead of assert so failures are not hidden by `-O`.
    if exceptions:
        raise RuntimeError(
            "Checkpoint conversion failed, please check error log."
        ) from exceptions[0]


def main():
    """Script entry point: write a config-only checkpoint or convert a model."""
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    tik = time.time()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.model_dir is None:
        # No source model: generate a config.json from the CLI arguments only.
        config = from_cli_args(args)
        with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)
    else:
        convert_and_save_xai(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()