import argparse
import json
import logging
import math
import re
import time
import typing
from pathlib import Path

import flax
import numpy as np
import orbax
import safetensors.torch
import torch
from recurrentgemma import jax as recurrentgemma_jax
from transformers import AutoConfig, AutoModelForCausalLM

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import (numpy_to_torch, str_dtype_to_torch,
                                 torch_to_numpy)

LOGGER = logging.getLogger("convert_checkpoint")
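
# This script converts a RecurrentGemma (Griffin) checkpoint, stored either as
# a JAX/orbax parameter tree or as a Hugging Face model directory, into a
# TensorRT-LLM checkpoint: a config.json plus one rank<N>.safetensors file per
# tensor-parallel rank.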


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_type", type=str, choices=["jax", "hf"])
    parser.add_argument("--model_dir", type=Path, default=None)
    parser.add_argument("--world_size",
                        type=int,
                        default=1,
                        help="world size; only tensor parallelism is supported")
    parser.add_argument("--dtype",
                        type=str,
                        default="float16",
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--output_dir",
        type=Path,
        default="recurrentgemma_tllm_checkpoint",
        help="The path to save the recurrentgemma TensorRT-LLM checkpoint")
    parser.add_argument("--log_level", type=str, default="info")
    args = parser.parse_args()
    return args
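
# Example usage (illustrative only; the model directory path is a placeholder,
# the flags are those defined in parse_arguments above):
#   python convert_checkpoint.py --ckpt_type hf \
#       --model_dir ./recurrentgemma-2b \
#       --dtype bfloat16 \
#       --output_dir ./recurrentgemma_tllm_checkpoint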


class JAXParser:

    def load_parameters(self, checkpoint_path: Path):
        checkpoint_path = checkpoint_path.absolute()
        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        params = checkpointer.restore(checkpoint_path)
        return params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["embedder"]["input_embedding"]

    def get_config(self, checkpoint_path, ckpt_params):
        config = recurrentgemma_jax.GriffinConfig.from_flax_params_or_variables(
            ckpt_params)._asdict()
        if config["lru_width"] is None:
            config["lru_width"] = config["width"]
        layer_types = []
        for p in config["block_types"]:
            if p == recurrentgemma_jax.TemporalBlockType.ATTENTION:
                layer_types.append("attention")
            else:
                layer_types.append("recurrent")
        config["block_types"] = layer_types
        config["hidden_size"] = config.pop("width")
        config["num_attention_heads"] = config.pop("num_heads")
        config["intermediate_size"] = config.pop("mlp_expanded_width")
        config["num_hidden_layers"] = len(config["block_types"])
        return config

    def rename_to_trt_llm(self, name: str):
        """Rename a recurrentgemma parameter name to the corresponding TRT-LLM style name."""
        sub_patterns = (
            (r"embedder.input_embedding", r"vocab_embedding.weight"),
            (r"blocks.(\d+).channel_pre_norm.scale",
             r"layers.\1.post_layernorm.weight"),
            (r"blocks.(\d+).temporal_pre_norm.scale",
             r"layers.\1.input_layernorm.weight"),
            (r"blocks.(\d+).recurrent_block.conv_1d.w",
             r"layers.\1.recurrent.conv1d.weight"),
            (r"blocks.(\d+).recurrent_block.conv_1d.b",
             r"layers.\1.recurrent.conv1d.bias"),
            (r"blocks.(\d+).recurrent_block.linear_out.kernel",
             r"layers.\1.recurrent.linear_out.weight"),
            (r"blocks.(\d+).recurrent_block.linear_out.bias",
             r"layers.\1.recurrent.linear_out.bias"),
            (r"blocks.(\d+).recurrent_block.linear_x.kernel",
             r"layers.\1.recurrent.linear_x.weight"),
            (r"blocks.(\d+).recurrent_block.linear_x.bias",
             r"layers.\1.recurrent.linear_x.bias"),
            (r"blocks.(\d+).recurrent_block.linear_y.kernel",
             r"layers.\1.recurrent.linear_y.weight"),
            (r"blocks.(\d+).recurrent_block.linear_y.bias",
             r"layers.\1.recurrent.y_bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_gate.w",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.weight"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_gate.b",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.input_gate.w",
             r"layers.\1.recurrent.rg_lru.input_gate.weight"),
            (r"blocks.(\d+).recurrent_block.rg_lru.input_gate.b",
             r"layers.\1.recurrent.rg_lru.input_gate.bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_param",
             r"layers.\1.recurrent.rg_lru.recurrent_param"),
            (r"blocks.(\d+).mlp_block.ffw_up.w", r"layers.\1.mlp.fc.weight"),
            (r"blocks.(\d+).mlp_block.ffw_up.b", None),
            (r"blocks.(\d+).mlp_block.ffw_down.kernel",
             r"layers.\1.mlp.proj.weight"),
            (r"blocks.(\d+).mlp_block.ffw_down.bias",
             r"layers.\1.mlp.proj.bias"),
            (r"blocks.(\d+).attention_block.proj_q.kernel",
             r"layers.\1.attention.qkv.weight"),
            (r"blocks.(\d+).attention_block.proj_k.kernel", None),
            (r"blocks.(\d+).attention_block.proj_v.kernel", None),
            (r"blocks.(\d+).attention_block.proj_final.kernel",
             r"layers.\1.attention.dense.weight"),
            (r"blocks.(\d+).attention_block.proj_final.bias",
             r"layers.\1.attention.dense.bias"),
            (r"final_norm.scale", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join(("transformer", name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        new_params = flax.traverse_util.flatten_dict(params, sep=".")
        # Cast any dtype other than float32/float16 (e.g. bfloat16) to float32.
        for k in new_params:
            if new_params[k].dtype != np.float32 and new_params[
                    k].dtype != np.float16:
                new_params[k] = new_params[k].astype(np.float32)
        return new_params
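
# For illustration: JAXParser.rename_to_trt_llm maps, e.g.,
#   "blocks.0.recurrent_block.linear_x.kernel"
#     -> "transformer.layers.0.recurrent.linear_x.weight"
# while "blocks.0.attention_block.proj_k.kernel" maps to None, because the
# key/value projections are folded into the fused qkv weight during conversion.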


class HfParser:

    def load_parameters(self, checkpoint_path: Path):
        hf_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map="auto",
            torch_dtype="auto",
        )
        model_params = dict(hf_model.named_parameters())
        return model_params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["model.embed_tokens.weight"]

    def get_config(self, checkpoint_path, ckpt_params):
        checkpoint_path = checkpoint_path.absolute()
        hf_config = AutoConfig.from_pretrained(
            checkpoint_path, trust_remote_code=True).to_dict()
        hf_config["block_types"] = hf_config.pop("_block_types")
        hf_config["intermediate_size"] = hf_config["intermediate_size"] // 2
        return hf_config

    def rename_to_trt_llm(self, name: str):
        """Rename a recurrentgemma parameter name to the corresponding TRT-LLM style name."""
        sub_patterns = (
            (r"model.embed_tokens.weight", r"vocab_embedding.weight"),
            (r"model.layers.(\d+).temporal_pre_norm.weight",
             r"layers.\1.input_layernorm.weight"),
            (r"model.layers.(\d+).channel_pre_norm.weight",
             r"layers.\1.post_layernorm.weight"),
            (r"model.layers.(\d+).temporal_block.conv_1d.weight",
             r"layers.\1.recurrent.conv1d.weight"),
            (r"model.layers.(\d+).temporal_block.conv_1d.bias",
             r"layers.\1.recurrent.conv1d.bias"),
            (r"model.layers.(\d+).temporal_block.linear_out.weight",
             r"layers.\1.recurrent.linear_out.weight"),
            (r"model.layers.(\d+).temporal_block.linear_out.bias",
             r"layers.\1.recurrent.linear_out.bias"),
            (r"model.layers.(\d+).temporal_block.linear_x.weight",
             r"layers.\1.recurrent.linear_x.weight"),
            (r"model.layers.(\d+).temporal_block.linear_x.bias",
             r"layers.\1.recurrent.linear_x.bias"),
            (r"model.layers.(\d+).temporal_block.linear_y.weight",
             r"layers.\1.recurrent.linear_y.weight"),
            (r"model.layers.(\d+).temporal_block.linear_y.bias",
             r"layers.\1.recurrent.y_bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_weight",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.weight"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_bias",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_weight",
             r"layers.\1.recurrent.rg_lru.input_gate.weight"),
            (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_bias",
             r"layers.\1.recurrent.rg_lru.input_gate.bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_param",
             r"layers.\1.recurrent.rg_lru.recurrent_param"),
            (r"model.layers.(\d+).mlp_block.up_proj.weight",
             r"layers.\1.mlp.gate.weight"),
            (r"model.layers.(\d+).mlp_block.up_proj.bias",
             r"layers.\1.mlp.gate.bias"),
            (r"model.layers.(\d+).mlp_block.gate_proj.weight",
             r"layers.\1.mlp.fc.weight"),
            (r"model.layers.(\d+).mlp_block.gate_proj.bias",
             r"layers.\1.mlp.fc.bias"),
            (r"model.layers.(\d+).mlp_block.down_proj.weight",
             r"layers.\1.mlp.proj.weight"),
            (r"model.layers.(\d+).mlp_block.down_proj.bias",
             r"layers.\1.mlp.proj.bias"),
            (r"model.layers.(\d+).temporal_block.q_proj.weight",
             r"layers.\1.attention.qkv.weight"),
            (r"model.layers.(\d+).temporal_block.k_proj.weight", None),
            (r"model.layers.(\d+).temporal_block.v_proj.weight", None),
            (r"model.layers.(\d+).temporal_block.o_proj.weight",
             r"layers.\1.attention.dense.weight"),
            (r"model.layers.(\d+).temporal_block.o_proj.bias",
             r"layers.\1.attention.dense.bias"),
            (r"model.final_norm.weight", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join(("transformer", name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        f_params = {}
        for k, v in params.items():
            if v.dtype == torch.bfloat16:
                v = v.float()
            f_params[k] = torch_to_numpy(v)
        return f_params
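
# For illustration: HfParser.rename_to_trt_llm maps, e.g.,
#   "model.layers.3.temporal_block.conv_1d.weight"
#     -> "transformer.layers.3.recurrent.conv1d.weight"
# while the separate k_proj/v_proj weights map to None and are merged into the
# fused qkv weight in convert_from_checkpoint() below.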


CKPT_PARSER = {
    "jax": JAXParser,
    "hf": HfParser,
}
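
# The parser class is instantiated in main() via CKPT_PARSER[args.ckpt_type]().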


def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    return np.split(v, tp_size, axis=dim)[idx]


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)
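
# For illustration: with a (4096, 1024) matrix, tp_size=2 and dim=0, rank 0
# receives rows 0..2047 and rank 1 receives rows 2048..4095; np.split requires
# the split dimension to be evenly divisible by tp_size.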


def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
                       name: str,
                       param: np.ndarray,
                       dtype: typing.Optional[np.dtype] = None):
    assert name not in weights, f"{name} is already added."
    param = numpy_to_torch(param)
    if dtype is not None:
        assert isinstance(dtype,
                          str), f"dtype must be str, but got type {type(dtype)}"
        param = param.to(str_dtype_to_torch(dtype))
    weights[name] = param.contiguous()
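
# For illustration:
#   add_trt_llm_weight(weights, "transformer.ln_f.weight", param, "float16")
# stores a contiguous torch tensor cast to float16 under that key; adding the
# same key twice trips the assertion above.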


def convert_from_checkpoint(
    trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
    model_dir: typing.Union[str, Path],
    ckpt_parser,
    rank=0,
):
    print("Loading weights...")
    tik = time.time()

    tp_rank = rank
    tp_size = trt_llm_config.mapping.tp_size
    intermediate_size = trt_llm_config.intermediate_size
    rnn_hidden_size = trt_llm_config.rnn_hidden_size
    conv_kernel = trt_llm_config.conv_kernel

    weights = {}
    for model_file in [model_dir]:
        LOGGER.debug(f"Loading directory {str(model_file)}...")
        model_params = ckpt_parser.load_parameters(model_file)
        model_params = ckpt_parser.flatten_params(model_params)

        for name, param in model_params.items():
            LOGGER.debug(f"Converting weight {name}...")
            trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
            if trt_llm_name is None:  # skipped: merged into another fused weight below
                continue

            if "proj_q" in name or "q_proj" in name:
                if isinstance(ckpt_parser, JAXParser):
                    k_name = name.replace("proj_q", "proj_k")
                    v_name = name.replace("proj_q", "proj_v")
                    q_param = param.transpose(1, 0)
                    k_param = model_params[k_name].transpose(1, 0)
                    v_param = model_params[v_name].transpose(1, 0)
                else:
                    k_name = name.replace("q_proj", "k_proj")
                    v_name = name.replace("q_proj", "v_proj")
                    q_param = param
                    k_param = model_params[k_name]
                    v_param = model_params[v_name]
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
                qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
            elif "ffw_up.w" in name and isinstance(ckpt_parser, JAXParser):
                bias_name = name.replace("ffw_up.w", "ffw_up.b")
                fc_param, gate_param = param[0, ].transpose(
                    1, 0), param[1, ].transpose(1, 0)
                fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
                gate_param = split_matrix_tp(gate_param,
                                             tp_size,
                                             tp_rank,
                                             dim=0)
                fc_bias = model_params[bias_name][0, ].reshape(
                    intermediate_size)
                gate_bias = model_params[bias_name][1, ].reshape(
                    intermediate_size)
                fc_bias = split_matrix_tp(fc_bias, tp_size, tp_rank, dim=0)
                gate_bias = split_matrix_tp(gate_bias, tp_size, tp_rank, dim=0)
                trt_llm_fc_name = trt_llm_name
                trt_llm_gate_name = trt_llm_name.replace(
                    "fc.weight", "gate.weight")
                trt_llm_fc_b_name = trt_llm_name.replace("fc.weight", "fc.bias")
                trt_llm_gate_b_name = trt_llm_name.replace(
                    "fc.weight", "gate.bias")
                add_trt_llm_weight(weights, trt_llm_fc_name, fc_param,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_gate_name, gate_param,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_fc_b_name, fc_bias,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_gate_b_name, gate_bias,
                                   trt_llm_config.dtype)
            elif "conv_1d.w" in name:
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1,
                                            0).reshape(rnn_hidden_size, 1,
                                                       conv_kernel, 1)
                else:
                    param = param.reshape(rnn_hidden_size, 1, conv_kernel, 1)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif "embedder.input_embedding" in name or "model.embed_tokens" in name:
                if not trt_llm_config.share_embedding_table:
                    lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                    add_trt_llm_weight(weights, "lm_head.weight",
                                       np.copy(lm_head), trt_llm_config.dtype)
                if trt_llm_config.emb_scale_by_sqrt_dim:
                    param = np.multiply(
                        param.astype(np.float32),
                        math.sqrt(trt_llm_config.hidden_size),
                    )
                if trt_llm_config.use_parallel_embedding:
                    assert trt_llm_config.vocab_size % tp_size == 0
                    param = split_matrix_tp(
                        param,
                        tp_size,
                        tp_rank,
                        dim=trt_llm_config.embedding_sharding_dim,
                    )
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.kernel",
                    "ffw_down.kernel",
                    "linear_out.kernel",
                    "o_proj.weight",
                    "down_proj.weight",
                    "linear_out.weight",
            )):
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.kernel",
                    "linear_y.kernel",
                    "linear_x.weight",
                    "linear_y.weight",
                    "up_proj.weight",
                    "gate_proj.weight",
            )):
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.bias",
                    "linear_y.bias",
                    "rg_lru",
                    "conv_1d.b",
                    "gate_proj.bias",
                    "up_proj.bias",
            )):
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "channel_pre_norm",
                    "temporal_pre_norm",
                    "final_norm",
            )):
                param = param + 1.0
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.bias",
                    "ffw_down.bias",
                    "linear_out.bias",
                    "o_proj.bias",
                    "down_proj.bias",
            )):
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            else:
                raise RuntimeError(f"Unhandled {name} module weights")
        del model_params

    print(
        f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
    )
    return weights
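
# Note: the returned dict maps TRT-LLM parameter names (e.g.
# "transformer.layers.0.attention.qkv.weight") to torch tensors already cast to
# trt_llm_config.dtype and sliced for the given tensor-parallel rank. Only the
# query projection of the fused qkv weight is split across ranks; the single
# key/value head is replicated on every rank.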


def convert(worker_rank, args, convert_kwargs):
    for rank in range(worker_rank, args.world_size):
        weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
        safetensors.torch.save_file(weights,
                                    args.output_dir / f"rank{rank}.safetensors")
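
# Note: ranks are converted serially in a single process; each TP rank's shard
# is written to <output_dir>/rank<N>.safetensors.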


def main():
    print(tensorrt_llm.__version__)

    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    print(f"Loading source parameters from {args.model_dir.absolute()}")
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()
    ckpt_params = ckpt_parser.load_parameters(args.model_dir)
    input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
    ckpt_params_dtype = str(input_embedding_weights.dtype).split(".")[-1]
    ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params)

    print(f"Source configuration determined from parameters: {ckpt_config}")

    quant_config = tensorrt_llm.models.modeling_utils.QuantConfig()
    trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
        architecture="RecurrentGemmaForCausalLM",
        dtype=args.dtype or ckpt_params_dtype,
        logits_dtype="float32",
        vocab_size=ckpt_config["vocab_size"],
        # follow the settings of the Gemma models
        max_position_embeddings=8192,
        hidden_size=ckpt_config["hidden_size"],
        num_hidden_layers=ckpt_config["num_hidden_layers"],
        num_attention_heads=ckpt_config["num_attention_heads"],
        num_key_value_heads=1,
        head_size=ckpt_config["hidden_size"] //
        ckpt_config["num_attention_heads"],
        hidden_act="gelu",
        intermediate_size=ckpt_config["intermediate_size"],
        norm_epsilon=1e-6,
        position_embedding_type="rope_gpt_neox",
        world_size=args.world_size,
        tp_size=args.world_size,
        pp_size=1,
        gpus_per_node=8,
        quantization=quant_config,
        conv_kernel=4,
        state_size=1,
        state_dtype='float32',
        rotary_pct=0.5,
        layer_types=ckpt_config["block_types"],
        rnn_hidden_size=ckpt_config["lru_width"],
        logits_soft_cap=ckpt_config["logits_soft_cap"],
        emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"],
    )

    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    config_path = args.output_dir / "config.json"
    config_path.parent.mkdir(exist_ok=True, parents=True)
    LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
    with config_path.open("w") as config_file:
        json.dump(trt_llm_config_dict, config_file, indent=4)

    convert_args = dict(trt_llm_config=trt_llm_config,
                        model_dir=args.model_dir,
                        ckpt_parser=ckpt_parser)
    convert(0, args, convert_args)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()
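
# After this script completes, output_dir holds config.json plus one
# rank<N>.safetensors file per rank; in the usual TensorRT-LLM workflow this
# checkpoint directory is then passed to trtllm-build (not invoked here) to
# build the engines.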