# TensorRT-LLM/examples/enc_dec/bart/convert.py
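"""Convert a Hugging Face BART/mBART (or Nougat) checkpoint into the binary
weight files and config.ini consumed by the TensorRT-LLM enc_dec example.

Example invocation (the paths below are placeholders):
    python convert.py -i /path/to/hf_model -o /path/to/trt_weights -i_g 1
"""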

import argparse
import configparser
import logging
import multiprocessing
import os
import re
from datetime import datetime
from pathlib import Path
dir_path = os.path.dirname(os.path.realpath(__file__))
import numpy as np
import torch # pytype: disable=import-error
from transformers import (AutoModelForSeq2SeqLM, MBartForConditionalGeneration,
VisionEncoderDecoderModel)
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.lora_manager import LoraConfig
LOGGER = logging.getLogger(__name__)
extra_configs = {
"structure": {
"t5_with_bias": "true",
"use_gated_activation": "false",
"position_embedding_type": "learned",
'model_type': 'bart'
}
} # TODO: remove model type as it's included in HF config's `architectures` attribute
# TODO: change name `t5_with_bias` for non-t5 model
def fuse_qkv(model, factor, saved_dir):
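    """Fuse the separate q/k/v projection weights (and biases, when present)
    of every attention module into a single qkv tensor, split it along the
    output dimension into `factor` tensor-parallel shards, and write each
    shard to `saved_dir` as a raw binary file."""
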
def get_attn_module(component, layer, attn_type):
m = model.model
m = getattr(m, component)
m = m.layers[int(layer)]
m = getattr(m, attn_type)
return m
for name, param in model.named_parameters():
if 'attn.q_proj.weight' in name:
# fuse weights of q, k, v
q_w = param
_, component, _, layer_idx, attn_type, *_ = name.split('.')
attn_mdl = get_attn_module(component, layer_idx, attn_type)
# fuse qkv weight
shape = q_w.shape # (do, din)
qkv_w = torch.cat(
[q_w, attn_mdl.k_proj.weight, attn_mdl.v_proj.weight],
dim=0).reshape([3, shape[0], shape[1]]) # (3, do, din)
qkv_w = torch_to_numpy(qkv_w)
split_vals = np.split(qkv_w, factor,
axis=1) # TODO: need to test using multi-gpu
for j in range(factor):
saved_path = saved_dir / f"model.{component}.layers.{layer_idx}.{attn_type}.qkv_proj.weight.{j}.bin"
split_vals[j].tofile(saved_path.as_posix())
# fuse qkv biases if present
            # nn.Linear always exposes a `bias` attribute (possibly None), so
            # check the value instead of using hasattr()
            if attn_mdl.q_proj.bias is not None:
q_b = attn_mdl.q_proj.bias
shape = q_b.shape[0] # (do,)
qkv_b = torch.cat(
[q_b, attn_mdl.k_proj.bias, attn_mdl.v_proj.bias],
dim=0).reshape([3, shape]) # (3, do)
qkv_b = torch_to_numpy(qkv_b)
split_vals = np.split(qkv_b, factor, axis=1) # (3, do / n_gpus)
for j in range(factor):
saved_path = saved_dir / f"model.{component}.layers.{layer_idx}.{attn_type}.qkv_proj.bias.{j}.bin"
split_vals[j].tofile(saved_path.as_posix())
# TODO: use re.compile to accelerate
def split_and_convert_process(key, val, factor, saved_dir):
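    """Convert one named weight from the HF state dict into raw binary files
    under `saved_dir`, splitting tensor-parallel weights into `factor` shards.
    Norms and embeddings are saved unsplit; q/k/v projections are skipped here
    and fused later by fuse_qkv()."""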
saved_key = key
LOGGER.debug(f"key: {key}, val.shape: {val.shape}")
def save_splits(split_vals):
for j in range(factor):
saved_path = saved_dir / f"{saved_key}.{j:d}.bin"
split_vals[j].tofile(saved_path.as_posix())
    if re.search(r'norm|embed_positions|(out_proj|fc2)\.bias', key) is not None:
saved_path = saved_dir / f"{saved_key}.bin"
if 'position' in key:
val = val[2:] # BART does not use first two position embeddings!
val.tofile(saved_path.as_posix())
    elif re.search(r'(lm_head|fc1)\.(weight|bias)', key) is not None:
split_vals = np.split(val, factor, axis=0)
save_splits(split_vals)
    elif re.search(r'[kqv]_proj\.(weight|bias)',
                   key) is not None:  # No need to store, fuse later!
pass
    elif re.search(
            r'(out_proj|fc2)\.weight',
            key) is not None:  # attention out_proj and FFN fc2, split along the input dim (axis -1)
split_vals = np.split(
val, factor, axis=-1
) # no need to split bias, each GPU will add it individually after all reduce
save_splits(split_vals) # TODO: support gated activation?
elif re.search('(en|de)coder.embed_tokens.weight', key) is not None:
saved_path = saved_dir / f"{saved_key}.bin"
val.tofile(saved_path.as_posix())
    elif 'final_logits_bias' in key:  # registered buffer BART uses to bias the output logits; not converted
pass
else:
        LOGGER.warning(
            f"unhandled key '{key}' with shape {val.shape}; weight was not converted")
def convert_checkpoint(args):
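    """Load the HF (m)BART or Nougat checkpoint, write a config.ini describing
    the encoder and decoder, and convert all weights into per-rank binary
    files under `<output_dir>/tp<inference_tensor_para_size>`."""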
    # LoRA: map Hugging Face LoRA module names to the TRT-LLM module names
encoder_hf_modules_to_trtllm_modules = {
"q_proj": "attn_q",
"v_proj": "attn_v",
} # encoder lora modules on bart
encoder_trtllm_modules_to_hf_modules = {
"attn_q": "q_proj",
"attn_v": "v_proj",
}
decoder_hf_modules_to_trtllm_modules = {
"q_proj": ["attn_q", "cross_attn_q"],
"v_proj": ["attn_v", "cross_attn_v"],
} # decoder lora modules on bart
decoder_trtllm_modules_to_hf_modules = {
"attn_q": "q_proj",
"attn_v": "v_proj",
"cross_attn_q": "q_proj",
"cross_attn_v": "v_proj",
}
encoder_lora_config = LoraConfig.from_hf(
args.hf_lora_dir, encoder_hf_modules_to_trtllm_modules,
encoder_trtllm_modules_to_hf_modules)
decoder_lora_config = LoraConfig.from_hf(
args.hf_lora_dir, decoder_hf_modules_to_trtllm_modules,
decoder_trtllm_modules_to_hf_modules)
args.encoder_lora_target_modules = encoder_lora_config.lora_target_modules
args.decoder_lora_target_modules = decoder_lora_config.lora_target_modules
args.encoder_max_lora_rank = 0
args.decoder_max_lora_rank = 0
if encoder_lora_config.is_valid:
args.encoder_max_lora_rank = encoder_lora_config.adapter_config['r']
if decoder_lora_config.is_valid:
args.decoder_max_lora_rank = decoder_lora_config.adapter_config['r']
    # the LoRA checkpoint might have fine-tuned the embeddings, so take the vocab sizes from it
args.encoder_vocab_size = encoder_lora_config.vocab_size
args.decoder_vocab_size = decoder_lora_config.vocab_size
args.encoder_lora_config = encoder_lora_config
args.decoder_lora_config = decoder_lora_config
saved_dir = Path(args.output_dir) / f"tp{args.inference_tensor_para_size}"
saved_dir.mkdir(parents=True, exist_ok=True)
if args.nougat:
model = VisionEncoderDecoderModel.from_pretrained(args.input_dir)
model = model.get_decoder()
else:
model = AutoModelForSeq2SeqLM.from_pretrained(args.input_dir)
model = model.to(str_dtype_to_torch(args.weight_data_type))
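    # Write a config.ini with the HF decoder/encoder configs plus the
    # TRT-LLM-specific fields needed when building the engines.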
config = configparser.ConfigParser()
config['decoder'] = dict()
for key, val in model.model.decoder.config.to_dict().items():
config["decoder"][key] = f"{val}"
config["decoder"]["weight_data_type"] = args.weight_data_type
config["decoder"]["q_scaling"] = '1'
config["decoder"]["rescale_before_lm_head"] = str(False)
config['decoder']['has_model_final_layernorm'] = str(
args.nougat or isinstance(model, MBartForConditionalGeneration))
config["decoder"]["max_lora_rank"] = str(args.decoder_max_lora_rank)
config["decoder"]["lora_target_modules"] = str(
args.decoder_lora_target_modules)
config["decoder"]["hf_modules_to_trtllm_modules"] = str(
decoder_hf_modules_to_trtllm_modules)
config["decoder"]["trtllm_modules_to_hf_modules"] = str(
decoder_trtllm_modules_to_hf_modules)
if args.nougat:
        # These flags are true for mBART decoders, but missing from the HF config
config['decoder']['normalize_before'] = str(True)
config['decoder']['normalize_embeddings'] = str(True)
config['encoder'] = dict()
        # Initialize the few encoder configs needed by the build from the decoder config
encoder_config_keys = [
"encoder_ffn_dim", "encoder_layers", "encoder_attention_heads",
"encoder_layerdrop", "d_model"
]
for key in encoder_config_keys:
config['encoder'][key] = config['decoder'][key]
else:
config['encoder'] = dict()
for key, val in model.model.encoder.config.to_dict().items():
config["encoder"][key] = f"{val}"
config["encoder"]["weight_data_type"] = args.weight_data_type
config["encoder"]["q_scaling"] = '1'
# mBART has final layernorm, BART does not
config['encoder']['has_model_final_layernorm'] = str(
isinstance(model, MBartForConditionalGeneration))
config["encoder"]["max_lora_rank"] = str(args.encoder_max_lora_rank)
config["encoder"]["lora_target_modules"] = str(
args.encoder_lora_target_modules)
config["encoder"]["hf_modules_to_trtllm_modules"] = str(
encoder_hf_modules_to_trtllm_modules)
config["encoder"]["trtllm_modules_to_hf_modules"] = str(
encoder_trtllm_modules_to_hf_modules)
# add additional config
for key, val in extra_configs.items():
config[key] = {}
for val_key, val_val in val.items():
config[key][val_key] = val_val
    with open((saved_dir / "config.ini").as_posix(), 'w') as configfile:
config.write(configfile)
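    # Convert and split every weight in parallel worker processes, one task per
    # entry in the model state dict.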
i_gpu_num = args.inference_tensor_para_size
pool = multiprocessing.Pool(args.processes)
pool.starmap_async(split_and_convert_process,
[(name, torch_to_numpy(param), i_gpu_num, saved_dir)
for name, param in model.state_dict().items()])
pool.close()
pool.join()
# fuse qkv weight and bias
fuse_qkv(model, i_gpu_num, saved_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--input_dir",
"-i",
type=str,
                        help="Path to the Hugging Face model checkpoint directory",
required=True)
parser.add_argument("--output_dir",
"-o",
type=str,
                        help="Output directory for the converted TRT-LLM weight files",
required=True)
parser.add_argument("--inference_tensor_para_size",
"-i_g",
type=int,
                        help="Tensor parallelism size, i.e. number of GPUs used for inference",
required=True)
parser.add_argument(
"--processes",
"-p",
type=int,
help="How many processes to spawn for conversion (default: 4)",
default=4)
parser.add_argument("--weight_data_type",
type=str,
default="float32",
choices=["float32", "float16",
"bfloat16"]) # TODO: test support for bf16?
parser.add_argument("--hf_lora_dir", type=str, default=None)
parser.add_argument("--nougat",
action="store_true",
                        help="Nougat-style model that uses a vision encoder with an mBART decoder")
parser.add_argument("--verbose",
action="store_true",
help="Provide verbose messages")
args = parser.parse_args()
log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
format=log_format)
    LOGGER.info("\n=============== Arguments ===============")
for key in vars(args):
LOGGER.info(f"{key}: {vars(args)[key]}")
LOGGER.info("========================================")
start_time = datetime.now()
convert_checkpoint(args)
stop_time = datetime.now()
run_time = (stop_time - start_time)
    LOGGER.info("Spent {} (h:m:s) converting the model".format(run_time))