# TensorRT-LLM/examples/recurrentgemma/convert_checkpoint.py

import argparse
import json
import logging
import math
import re
import time
import typing
from pathlib import Path

import flax
import numpy as np
import orbax
import safetensors.torch
import torch
from recurrentgemma import jax as recurrentgemma_jax
from transformers import AutoConfig, AutoModelForCausalLM

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import (numpy_to_torch, str_dtype_to_torch,
                                 torch_to_numpy)

LOGGER = logging.getLogger("convert_checkpoint")


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_type", type=str, choices=["jax", "hf"])
    parser.add_argument("--model_dir", type=Path, default=None)
    parser.add_argument("--world_size",
                        type=int,
                        default=1,
                        help="world size, only support tensor parallelism now")
    parser.add_argument("--dtype",
                        type=str,
                        default="float16",
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--output_dir",
        type=Path,
        default="recurrentgemma_tllm_checkpoint",
        help="The path to save the recurrentgemma TensorRT-LLM checkpoint")
    parser.add_argument("--log_level", type=str, default="info")
    args = parser.parse_args()
    return args
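
# Example invocation (paths are illustrative):
#   python convert_checkpoint.py --ckpt_type hf \
#       --model_dir ./recurrentgemma-2b \
#       --dtype bfloat16 \
#       --world_size 1 \
#       --output_dir ./recurrentgemma_tllm_checkpoint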


class JAXParser:

    def load_parameters(self, checkpoint_path: Path):
        checkpoint_path = checkpoint_path.absolute()
        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        params = checkpointer.restore(checkpoint_path)
        return params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["embedder"]["input_embedding"]

    def get_config(self, checkpoint_path, ckpt_params):
        config = recurrentgemma_jax.GriffinConfig.from_flax_params_or_variables(
            ckpt_params)._asdict()
        if config["lru_width"] is None:
            config["lru_width"] = config["width"]
        layer_types = []
        for p in config["block_types"]:
            if p == recurrentgemma_jax.TemporalBlockType.ATTENTION:
                layer_types.append("attention")
            else:
                layer_types.append("recurrent")
        config["block_types"] = layer_types
        config["hidden_size"] = config.pop("width")
        config["num_attention_heads"] = config.pop("num_heads")
        config["intermediate_size"] = config.pop("mlp_expanded_width")
        config["num_hidden_layers"] = len(config["block_types"])
        return config

    def rename_to_trt_llm(self, name: str):
        """Rename a recurrentgemma parameter name to the corresponding TRT-LLM style name.

        A target of None means the parameter is consumed while another weight is
        converted (e.g. proj_k/proj_v are folded into the fused qkv weight, and
        ffw_up.b is handled together with ffw_up.w), so no tensor is emitted for
        it directly.
        """
        sub_patterns = (
            (r"embedder.input_embedding", r"vocab_embedding.weight"),
            (r"blocks.(\d+).channel_pre_norm.scale",
             r"layers.\1.post_layernorm.weight"),
            (r"blocks.(\d+).temporal_pre_norm.scale",
             r"layers.\1.input_layernorm.weight"),
            (r"blocks.(\d+).recurrent_block.conv_1d.w",
             r"layers.\1.recurrent.conv1d.weight"),
            (r"blocks.(\d+).recurrent_block.conv_1d.b",
             r"layers.\1.recurrent.conv1d.bias"),
            (r"blocks.(\d+).recurrent_block.linear_out.kernel",
             r"layers.\1.recurrent.linear_out.weight"),
            (r"blocks.(\d+).recurrent_block.linear_out.bias",
             r"layers.\1.recurrent.linear_out.bias"),
            (r"blocks.(\d+).recurrent_block.linear_x.kernel",
             r"layers.\1.recurrent.linear_x.weight"),
            (r"blocks.(\d+).recurrent_block.linear_x.bias",
             r"layers.\1.recurrent.linear_x.bias"),
            (r"blocks.(\d+).recurrent_block.linear_y.kernel",
             r"layers.\1.recurrent.linear_y.weight"),
            (r"blocks.(\d+).recurrent_block.linear_y.bias",
             r"layers.\1.recurrent.y_bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_gate.w",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.weight"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_gate.b",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.input_gate.w",
             r"layers.\1.recurrent.rg_lru.input_gate.weight"),
            (r"blocks.(\d+).recurrent_block.rg_lru.input_gate.b",
             r"layers.\1.recurrent.rg_lru.input_gate.bias"),
            (r"blocks.(\d+).recurrent_block.rg_lru.a_param",
             r"layers.\1.recurrent.rg_lru.recurrent_param"),
            (r"blocks.(\d+).mlp_block.ffw_up.w", r"layers.\1.mlp.fc.weight"),
            (r"blocks.(\d+).mlp_block.ffw_up.b", None),
            (r"blocks.(\d+).mlp_block.ffw_down.kernel",
             r"layers.\1.mlp.proj.weight"),
            (r"blocks.(\d+).mlp_block.ffw_down.bias",
             r"layers.\1.mlp.proj.bias"),
            (r"blocks.(\d+).attention_block.proj_q.kernel",
             r"layers.\1.attention.qkv.weight"),
            (r"blocks.(\d+).attention_block.proj_k.kernel", None),
            (r"blocks.(\d+).attention_block.proj_v.kernel", None),
            (r"blocks.(\d+).attention_block.proj_final.kernel",
             r"layers.\1.attention.dense.weight"),
            (r"blocks.(\d+).attention_block.proj_final.bias",
             r"layers.\1.attention.dense.bias"),
            (r"final_norm.scale", r"ln_f.weight"),
        )
        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join(("transformer", name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        new_params = flax.traverse_util.flatten_dict(params, sep=".")
        # cast anything that is not float32/float16 (e.g. bfloat16) to float32
        for k in new_params:
            if new_params[k].dtype != np.float32 and new_params[
                    k].dtype != np.float16:
                new_params[k] = new_params[k].astype(np.float32)
        return new_params


class HfParser:

    def load_parameters(self, checkpoint_path: Path):
        hf_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map="auto",
            torch_dtype="auto",
        )
        model_params = dict(hf_model.named_parameters())
        return model_params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["model.embed_tokens.weight"]

    def get_config(self, checkpoint_path, ckpt_params):
        checkpoint_path = checkpoint_path.absolute()
        hf_config = AutoConfig.from_pretrained(
            checkpoint_path, trust_remote_code=True).to_dict()
        hf_config["block_types"] = hf_config.pop("_block_types")
        # The HF config reports the fused (gate + up) MLP width; TRT-LLM expects
        # the width of a single projection, hence the halving.
        hf_config["intermediate_size"] = hf_config["intermediate_size"] // 2
        return hf_config

    def rename_to_trt_llm(self, name: str):
        """Rename a recurrentgemma parameter name to the corresponding TRT-LLM style name."""
        sub_patterns = (
            (r"model.embed_tokens.weight", r"vocab_embedding.weight"),
            (r"model.layers.(\d+).temporal_pre_norm.weight",
             r"layers.\1.input_layernorm.weight"),
            (r"model.layers.(\d+).channel_pre_norm.weight",
             r"layers.\1.post_layernorm.weight"),
            (r"model.layers.(\d+).temporal_block.conv_1d.weight",
             r"layers.\1.recurrent.conv1d.weight"),
            (r"model.layers.(\d+).temporal_block.conv_1d.bias",
             r"layers.\1.recurrent.conv1d.bias"),
            (r"model.layers.(\d+).temporal_block.linear_out.weight",
             r"layers.\1.recurrent.linear_out.weight"),
            (r"model.layers.(\d+).temporal_block.linear_out.bias",
             r"layers.\1.recurrent.linear_out.bias"),
            (r"model.layers.(\d+).temporal_block.linear_x.weight",
             r"layers.\1.recurrent.linear_x.weight"),
            (r"model.layers.(\d+).temporal_block.linear_x.bias",
             r"layers.\1.recurrent.linear_x.bias"),
            (r"model.layers.(\d+).temporal_block.linear_y.weight",
             r"layers.\1.recurrent.linear_y.weight"),
            (r"model.layers.(\d+).temporal_block.linear_y.bias",
             r"layers.\1.recurrent.y_bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_weight",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.weight"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_bias",
             r"layers.\1.recurrent.rg_lru.recurrent_gate.bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_weight",
             r"layers.\1.recurrent.rg_lru.input_gate.weight"),
            (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_bias",
             r"layers.\1.recurrent.rg_lru.input_gate.bias"),
            (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_param",
             r"layers.\1.recurrent.rg_lru.recurrent_param"),
            (r"model.layers.(\d+).mlp_block.up_proj.weight",
             r"layers.\1.mlp.gate.weight"),
            (r"model.layers.(\d+).mlp_block.up_proj.bias",
             r"layers.\1.mlp.gate.bias"),
            (r"model.layers.(\d+).mlp_block.gate_proj.weight",
             r"layers.\1.mlp.fc.weight"),
            (r"model.layers.(\d+).mlp_block.gate_proj.bias",
             r"layers.\1.mlp.fc.bias"),
            (r"model.layers.(\d+).mlp_block.down_proj.weight",
             r"layers.\1.mlp.proj.weight"),
            (r"model.layers.(\d+).mlp_block.down_proj.bias",
             r"layers.\1.mlp.proj.bias"),
            (r"model.layers.(\d+).temporal_block.q_proj.weight",
             r"layers.\1.attention.qkv.weight"),
            (r"model.layers.(\d+).temporal_block.k_proj.weight", None),
            (r"model.layers.(\d+).temporal_block.v_proj.weight", None),
            (r"model.layers.(\d+).temporal_block.o_proj.weight",
             r"layers.\1.attention.dense.weight"),
            (r"model.layers.(\d+).temporal_block.o_proj.bias",
             r"layers.\1.attention.dense.bias"),
            (r"model.final_norm.weight", r"ln_f.weight"),
        )
        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join(("transformer", name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        f_params = {}
        for k, v in params.items():
            # numpy has no native bfloat16, so upcast before the conversion
            if v.dtype == torch.bfloat16:
                v = v.float()
            f_params[k] = torch_to_numpy(v)
        return f_params
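

# Both parsers expose the same interface used below (load_parameters,
# embedding_weights, get_config, rename_to_trt_llm, flatten_params), so the
# rest of the script stays agnostic to the source checkpoint format.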
CKPT_PARSER = {
    "jax": JAXParser,
    "hf": HfParser,
}


def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    return np.split(v, tp_size, axis=dim)[idx]
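
# For example, split(np.arange(8), tp_size=2, idx=1) returns array([4, 5, 6, 7]):
# each rank keeps one equally sized contiguous chunk along `dim`.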


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
                       name: str,
                       param: np.ndarray,
                       dtype: typing.Optional[str] = None):
    assert name not in weights, f"{name} is already added."
    param = numpy_to_torch(param)
    if dtype is not None:
        assert isinstance(dtype,
                          str), f"dtype must be str, but got type {type(dtype)}"
        param = param.to(str_dtype_to_torch(dtype))
    weights[name] = param.contiguous()


def convert_from_checkpoint(
    trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
    model_dir: typing.Union[str, Path],
    ckpt_parser,
    rank=0,
):
    print("Loading weights...")
    tik = time.time()

    tp_rank = rank
    tp_size = trt_llm_config.mapping.tp_size
    intermediate_size = trt_llm_config.intermediate_size
    rnn_hidden_size = trt_llm_config.rnn_hidden_size
    conv_kernel = trt_llm_config.conv_kernel

    weights = {}
    for model_file in [model_dir]:
        LOGGER.debug(f"Loading directory {str(model_file)}...")
        model_params = ckpt_parser.load_parameters(model_file)
        model_params = ckpt_parser.flatten_params(model_params)

        for name, param in model_params.items():
            LOGGER.debug(f"Converting weight {name}...")
            trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
            if trt_llm_name is None:  # omit as used with other params
                continue

            if "proj_q" in name or "q_proj" in name:
                # Fuse Q, K and V into a single qkv weight. Only the query
                # projection is split across TP ranks; the single K/V head is
                # replicated on every rank.
                if isinstance(ckpt_parser, JAXParser):
                    k_name = name.replace("proj_q", "proj_k")
                    v_name = name.replace("proj_q", "proj_v")
                    q_param = param.transpose(1, 0)
                    k_param = model_params[k_name].transpose(1, 0)
                    v_param = model_params[v_name].transpose(1, 0)
                else:
                    k_name = name.replace("q_proj", "k_proj")
                    v_name = name.replace("q_proj", "v_proj")
                    q_param = param
                    k_param = model_params[k_name]
                    v_param = model_params[v_name]
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
                qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
elif "ffw_up.w" in name and isinstance(ckpt_parser, JAXParser):
bias_name = name.replace("ffw_up.w", "ffw_up.b")
fc_param, gate_param = param[0, ].transpose(
1, 0), param[1, ].transpose(1, 0)
fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
gate_param = split_matrix_tp(gate_param,
tp_size,
tp_rank,
dim=0)
fc_bias = model_params[bias_name][0, ].reshape(
intermediate_size)
gate_bias = model_params[bias_name][1, ].reshape(
intermediate_size)
fc_bias = split_matrix_tp(fc_bias, tp_size, tp_rank, dim=0)
gate_bias = split_matrix_tp(gate_bias, tp_size, tp_rank, dim=0)
trt_llm_fc_name = trt_llm_name
trt_llm_gate_name = trt_llm_name.replace(
"fc.weight", "gate.weight")
trt_llm_fc_b_name = trt_llm_name.replace("fc.weight", "fc.bias")
trt_llm_gate_b_name = trt_llm_name.replace(
"fc.weight", "gate.bias")
add_trt_llm_weight(weights, trt_llm_fc_name, fc_param,
trt_llm_config.dtype)
add_trt_llm_weight(weights, trt_llm_gate_name, gate_param,
trt_llm_config.dtype)
add_trt_llm_weight(weights, trt_llm_fc_b_name, fc_bias,
trt_llm_config.dtype)
add_trt_llm_weight(weights, trt_llm_gate_b_name, gate_bias,
trt_llm_config.dtype)
elif "conv_1d.w" in name:
if isinstance(ckpt_parser, JAXParser):
param = param.transpose(1,
0).reshape(rnn_hidden_size, 1,
conv_kernel, 1)
else:
param = param.reshape(rnn_hidden_size, 1, conv_kernel, 1)
param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "embedder.input_embedding" in name or "model.embed_tokens" in name:
if not trt_llm_config.share_embedding_table:
lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
add_trt_llm_weight(weights, "lm_head.weight",
np.copy(lm_head), trt_llm_config.dtype)
if trt_llm_config.emb_scale_by_sqrt_dim:
param = np.multiply(
param.astype(np.float32),
math.sqrt(trt_llm_config.hidden_size),
)
if trt_llm_config.use_parallel_embedding:
assert trt_llm_config.vocab_size % tp_size == 0
param = split_matrix_tp(
param,
tp_size,
tp_rank,
dim=trt_llm_config.embedding_sharding_dim,
)
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.kernel",
                    "ffw_down.kernel",
                    "linear_out.kernel",
                    "o_proj.weight",
                    "down_proj.weight",
                    "linear_out.weight",
            )):
                # Output (row-parallel) projections: split along the input dim.
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.kernel",
                    "linear_y.kernel",
                    "linear_x.weight",
                    "linear_y.weight",
                    "up_proj.weight",
                    "gate_proj.weight",
            )):
                # Input (column-parallel) projections: split along the output dim.
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.bias",
                    "linear_y.bias",
                    "rg_lru",
                    "conv_1d.b",
                    "gate_proj.bias",
                    "up_proj.bias",
            )):
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "channel_pre_norm",
                    "temporal_pre_norm",
                    "final_norm",
            )):
                # The checkpoint stores RMSNorm scales as an offset from 1
                # (Gemma-style "x * (1 + weight)"), so fold the +1 in here.
                param = param + 1.0
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.bias",
                    "ffw_down.bias",
                    "linear_out.bias",
                    "o_proj.bias",
                    "down_proj.bias",
            )):
                # These biases are not tensor-parallel split.
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            else:
                raise RuntimeError(f"Unhandled {name} module weights")

        del model_params

    print(
        f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
    )
    return weights


def convert(worker_rank, args, convert_kwargs):
    for rank in range(worker_rank, args.world_size):
        weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
        safetensors.torch.save_file(weights,
                                    args.output_dir / f"rank{rank}.safetensors")
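

# With --world_size 2, the output directory ends up containing:
#   <output_dir>/config.json          (written by main)
#   <output_dir>/rank0.safetensors    (one file per tensor-parallel rank)
#   <output_dir>/rank1.safetensors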


def main():
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    print(f"Loading source parameters from {args.model_dir.absolute()}")
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()
    ckpt_params = ckpt_parser.load_parameters(args.model_dir)
    input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
    ckpt_params_dtype = str(input_embedding_weights.dtype).split(".")[-1]
    ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params)
    print(f"Source configuration determined from parameters: {ckpt_config}")

    quant_config = tensorrt_llm.models.modeling_utils.QuantConfig()
    trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
        architecture="RecurrentGemmaForCausalLM",
        dtype=args.dtype or ckpt_params_dtype,
        logits_dtype="float32",
        vocab_size=ckpt_config["vocab_size"],
        # follow the setting of gemma models
        max_position_embeddings=8192,
        hidden_size=ckpt_config["hidden_size"],
        num_hidden_layers=ckpt_config["num_hidden_layers"],
        num_attention_heads=ckpt_config["num_attention_heads"],
        num_key_value_heads=1,
        head_size=ckpt_config["hidden_size"] //
        ckpt_config["num_attention_heads"],
        hidden_act="gelu",
        intermediate_size=ckpt_config["intermediate_size"],
        norm_epsilon=1e-6,
        position_embedding_type="rope_gpt_neox",
        world_size=args.world_size,
        tp_size=args.world_size,
        pp_size=1,
        gpus_per_node=8,
        quantization=quant_config,
        conv_kernel=4,
        state_size=1,
        state_dtype='float32',
        rotary_pct=0.5,
        layer_types=ckpt_config["block_types"],
        rnn_hidden_size=ckpt_config["lru_width"],
        logits_soft_cap=ckpt_config["logits_soft_cap"],
        emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"],
    )

    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    config_path = args.output_dir / "config.json"
    config_path.parent.mkdir(exist_ok=True, parents=True)
    LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
    with config_path.open("w") as config_file:
        json.dump(trt_llm_config_dict, config_file, indent=4)

    convert_args = dict(trt_llm_config=trt_llm_config,
                        model_dir=args.model_dir,
                        ckpt_parser=ckpt_parser)
    convert(0, args, convert_args)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()