import argparse
import json
import logging
import math
import re
import time
import typing
from pathlib import Path

import flax
import numpy as np
import orbax.checkpoint
import safetensors.torch
import torch
from recurrentgemma import jax as recurrentgemma_jax
from transformers import AutoConfig, AutoModelForCausalLM

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import (numpy_to_torch, str_dtype_to_torch,
                                 torch_to_numpy)

LOGGER = logging.getLogger("convert_checkpoint")


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_type", type=str, choices=["jax", "hf"])
    parser.add_argument("--model_dir", type=Path, default=None)
    parser.add_argument("--world_size",
                        type=int,
                        default=1,
                        help="world size; only tensor parallelism is supported")
    parser.add_argument("--dtype",
                        type=str,
                        default="float16",
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--output_dir",
        type=Path,
        default="recurrentgemma_tllm_checkpoint",
        help="The path to save the recurrentgemma TensorRT-LLM checkpoint")
    parser.add_argument("--log_level", type=str, default="info")
    args = parser.parse_args()
    return args


class JAXParser:

    def load_parameters(self, checkpoint_path: Path):
        checkpoint_path = checkpoint_path.absolute()
        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        params = checkpointer.restore(checkpoint_path)
        return params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["embedder"]["input_embedding"]

    def get_config(self, checkpoint_path, ckpt_params):
        config = recurrentgemma_jax.GriffinConfig.from_flax_params_or_variables(
            ckpt_params)._asdict()
        if config["lru_width"] is None:
            config["lru_width"] = config["width"]
        layer_types = []
        for p in config["block_types"]:
            if p == recurrentgemma_jax.TemporalBlockType.ATTENTION:
                layer_types.append("attention")
            else:
                layer_types.append("recurrent")
        config["block_types"] = layer_types
        config["hidden_size"] = config.pop("width")
        config["num_attention_heads"] = config.pop("num_heads")
        config["intermediate_size"] = config.pop("mlp_expanded_width")
        config["num_hidden_layers"] = len(config["block_types"])
        return config

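    # The regex table below maps recurrentgemma (Flax) parameter names to
    # TRT-LLM names. Entries with a None target (the k/v projections and the
    # fused ffw_up bias) are intentionally not converted on their own:
    # rename_to_trt_llm returns None for them, and convert_from_checkpoint
    # folds them into another weight (the fused qkv, or the fc/gate biases).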
(r"blocks.(\d+).recurrent_block.rg_lru.input_gate.b", r"layers.\1.temporal_block.input_gate.bias"), (r"blocks.(\d+).recurrent_block.rg_lru.a_param", r"layers.\1.temporal_block.A"), (r"blocks.(\d+).mlp_block.ffw_up.w", r"layers.\1.mlp.fc.weight"), (r"blocks.(\d+).mlp_block.ffw_up.b", None), (r"blocks.(\d+).mlp_block.ffw_down.kernel", r"layers.\1.mlp.proj.weight"), (r"blocks.(\d+).mlp_block.ffw_down.bias", r"layers.\1.mlp.proj.bias"), (r"blocks.(\d+).attention_block.proj_q.kernel", r"layers.\1.temporal_block.qkv.weight"), (r"blocks.(\d+).attention_block.proj_k.kernel", None), (r"blocks.(\d+).attention_block.proj_v.kernel", None), (r"blocks.(\d+).attention_block.proj_final.kernel", r"layers.\1.temporal_block.dense.weight"), (r"blocks.(\d+).attention_block.proj_final.bias", r"layers.\1.temporal_block.dense.bias"), (r"final_norm.scale", r"norm_f.weight"), ) for source, target in sub_patterns: if re.match(source, name): if target is None: return target else: name = re.sub(source, target, name) return ".".join(("backbone", name)) else: raise ValueError(f"Don't know how to rename {name}") def flatten_params(self, params): new_params = flax.traverse_util.flatten_dict(params, sep=".") # if the dtype is bfloat16, cast to float32 for k in new_params: if new_params[k].dtype != np.float32 and new_params[ k].dtype != np.float16: new_params[k] = new_params[k].astype(np.float32) return new_params class HfParser: def load_parameters(self, checkpoint_path: Path): hf_model = AutoModelForCausalLM.from_pretrained( checkpoint_path, device_map="auto", torch_dtype="auto", ) model_params = dict(hf_model.named_parameters()) return model_params def embedding_weights(self, ckpt_params): return ckpt_params["model.embed_tokens.weight"] def get_config(self, checkpoint_path, ckpt_params): checkpoint_path = checkpoint_path.absolute() hf_config = AutoConfig.from_pretrained( checkpoint_path, trust_remote_code=True).to_dict() hf_config["block_types"] = hf_config.pop("_block_types") hf_config["intermediate_size"] = hf_config["intermediate_size"] // 2 return hf_config def rename_to_trt_llm(self, name: str): """Rename a recurrentgemma parameter name by the corresponding TRT-LLM style name.""" sub_patterns = ( (r"model.embed_tokens.weight", r"vocab_embedding.weight"), (r"model.layers.(\d+).temporal_pre_norm.weight", r"layers.\1.input_layernorm.weight"), (r"model.layers.(\d+).channel_pre_norm.weight", r"layers.\1.post_layernorm.weight"), (r"model.layers.(\d+).temporal_block.conv_1d.weight", r"layers.\1.temporal_block.conv1d.weight"), (r"model.layers.(\d+).temporal_block.conv_1d.bias", r"layers.\1.temporal_block.conv1d.bias"), (r"model.layers.(\d+).temporal_block.linear_out.weight", r"layers.\1.temporal_block.linear_out.weight"), (r"model.layers.(\d+).temporal_block.linear_out.bias", r"layers.\1.temporal_block.linear_out.bias"), (r"model.layers.(\d+).temporal_block.linear_x.weight", r"layers.\1.temporal_block.linear_x.weight"), (r"model.layers.(\d+).temporal_block.linear_x.bias", r"layers.\1.temporal_block.linear_x.bias"), (r"model.layers.(\d+).temporal_block.linear_y.weight", r"layers.\1.temporal_block.linear_y.weight"), (r"model.layers.(\d+).temporal_block.linear_y.bias", r"layers.\1.temporal_block.y_bias"), (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_weight", r"layers.\1.temporal_block.a_gate.weight"), (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_gate_bias", r"layers.\1.temporal_block.a_gate.bias"), (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_weight", 
r"layers.\1.temporal_block.input_gate.weight"), (r"model.layers.(\d+).temporal_block.rg_lru.input_gate_bias", r"layers.\1.temporal_block.input_gate.bias"), (r"model.layers.(\d+).temporal_block.rg_lru.recurrent_param", r"layers.\1.temporal_block.A"), (r"model.layers.(\d+).mlp_block.up_proj.weight", r"layers.\1.mlp.gate.weight"), (r"model.layers.(\d+).mlp_block.up_proj.bias", r"layers.\1.mlp.gate.bias"), (r"model.layers.(\d+).mlp_block.gate_proj.weight", r"layers.\1.mlp.fc.weight"), (r"model.layers.(\d+).mlp_block.gate_proj.bias", r"layers.\1.mlp.fc.bias"), (r"model.layers.(\d+).mlp_block.down_proj.weight", r"layers.\1.mlp.proj.weight"), (r"model.layers.(\d+).mlp_block.down_proj.bias", r"layers.\1.mlp.proj.bias"), (r"model.layers.(\d+).temporal_block.q_proj.weight", r"layers.\1.temporal_block.qkv.weight"), (r"model.layers.(\d+).temporal_block.k_proj.weight", None), (r"model.layers.(\d+).temporal_block.v_proj.weight", None), (r"model.layers.(\d+).temporal_block.o_proj.weight", r"layers.\1.temporal_block.dense.weight"), (r"model.layers.(\d+).temporal_block.o_proj.bias", r"layers.\1.temporal_block.dense.bias"), (r"model.final_norm.weight", r"norm_f.weight"), ) for source, target in sub_patterns: if re.match(source, name): if target is None: return target else: name = re.sub(source, target, name) return ".".join(("backbone", name)) else: raise ValueError(f"Don't know how to rename {name}") def flatten_params(self, params): f_params = {} for k, v in params.items(): if v.dtype == torch.bfloat16: v = v.float() f_params[k] = torch_to_numpy(v) return f_params CKPT_PARSER = { "jax": JAXParser, "hf": HfParser, } def split(v, tp_size, idx, dim=0): if tp_size == 1: return v return np.split(v, tp_size, axis=dim)[idx] def split_matrix_tp(v, tensor_parallel, rank, dim): return split(v, tensor_parallel, rank, dim=dim) def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray], name: str, param: np.ndarray, dtype: typing.Optional[np.dtype] = None): assert name not in weights, f"{name} is already added." 
def convert_from_checkpoint(
    trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
    model_dir: typing.Union[str, Path],
    ckpt_parser,
    rank=0,
):
    print("Loading weights...")
    tik = time.time()

    tp_rank = rank
    tp_size = trt_llm_config.mapping.tp_size
    intermediate_size = trt_llm_config.intermediate_size
    rnn_hidden_size = trt_llm_config.rnn_hidden_size
    conv_kernel = trt_llm_config.conv_kernel

    weights = {}
    for model_file in [model_dir]:
        LOGGER.debug(f"Loading directory {str(model_file)}...")
        model_params = ckpt_parser.load_parameters(model_file)
        model_params = ckpt_parser.flatten_params(model_params)

        for name, param in model_params.items():
            LOGGER.debug(f"Converting weight {name}...")
            trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
            if trt_llm_name is None:  # skipped here; fused into another weight
                continue

            if "proj_q" in name or "q_proj" in name:
                if isinstance(ckpt_parser, JAXParser):
                    k_name = name.replace("proj_q", "proj_k")
                    v_name = name.replace("proj_q", "proj_v")
                    q_param = param.transpose(1, 0)
                    k_param = model_params[k_name].transpose(1, 0)
                    v_param = model_params[v_name].transpose(1, 0)
                else:
                    k_name = name.replace("q_proj", "k_proj")
                    v_name = name.replace("q_proj", "v_proj")
                    q_param = param
                    k_param = model_params[k_name]
                    v_param = model_params[v_name]
                # multi-query attention: only q is sharded across ranks,
                # the single k/v head is replicated
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
                qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
            elif "ffw_up.w" in name and isinstance(ckpt_parser, JAXParser):
                bias_name = name.replace("ffw_up.w", "ffw_up.b")
                fc_param, gate_param = param[0, ].transpose(
                    1, 0), param[1, ].transpose(1, 0)
                fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
                gate_param = split_matrix_tp(gate_param,
                                             tp_size,
                                             tp_rank,
                                             dim=0)
                fc_bias = model_params[bias_name][0, ].reshape(
                    intermediate_size)
                gate_bias = model_params[bias_name][1, ].reshape(
                    intermediate_size)
                trt_llm_fc_name = trt_llm_name
                trt_llm_gate_name = trt_llm_name.replace(
                    "fc.weight", "gate.weight")
                trt_llm_fc_b_name = trt_llm_name.replace("fc.weight", "fc.bias")
                trt_llm_gate_b_name = trt_llm_name.replace(
                    "fc.weight", "gate.bias")
                add_trt_llm_weight(weights, trt_llm_fc_name, fc_param,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_gate_name, gate_param,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_fc_b_name, fc_bias,
                                   trt_llm_config.dtype)
                add_trt_llm_weight(weights, trt_llm_gate_b_name, gate_bias,
                                   trt_llm_config.dtype)
            elif "conv_1d.w" in name:
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0).reshape(
                        rnn_hidden_size, 1, conv_kernel, 1)
                else:
                    param = param.reshape(rnn_hidden_size, 1, conv_kernel, 1)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif "embedder.input_embedding" in name or "model.embed_tokens" in name:
                if not trt_llm_config.share_embedding_table:
                    lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                    add_trt_llm_weight(weights, "lm_head.weight",
                                       np.copy(lm_head), trt_llm_config.dtype)
                if trt_llm_config.emb_scale_by_sqrt_dim:
                    param = np.multiply(
                        param.astype(np.float32),
                        math.sqrt(trt_llm_config.hidden_size),
                    )
                if trt_llm_config.use_parallel_embedding:
                    assert trt_llm_config.vocab_size % tp_size == 0
                    param = split_matrix_tp(
                        param,
                        tp_size,
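            # The remaining branches handle dense projection matrices (JAX
            # kernels are stored as (in, out) and transposed first), RG-LRU
            # parameters and per-channel bias vectors split along dim 0, norm
            # scales (stored zero-centered in the checkpoint, hence the +1.0),
            # and plain bias vectors copied through unchanged.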
                        tp_rank,
                        dim=trt_llm_config.embedding_sharding_dim,
                    )
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.kernel",
                    "ffw_down.kernel",
                    "linear_out.kernel",
                    "o_proj.weight",
                    "up_proj.weight",
                    "gate_proj.weight",
                    "down_proj.weight",
                    "linear_out.weight",
            )):
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.kernel",
                    "linear_y.kernel",
                    "linear_x.weight",
                    "linear_y.weight",
            )):
                if isinstance(ckpt_parser, JAXParser):
                    param = param.transpose(1, 0)
                param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "linear_x.bias",
                    "linear_y.bias",
                    "rg_lru",
                    "conv_1d.b",
            )):
                param = split_matrix_tp(param, tp_size, tp_rank, dim=0)
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "channel_pre_norm",
                    "temporal_pre_norm",
                    "final_norm",
            )):
                param = param + 1.0
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            elif any(keyword in name for keyword in (
                    "proj_final.bias",
                    "ffw_down.bias",
                    "linear_out.bias",
                    "o_proj.bias",
                    "up_proj.bias",
                    "gate_proj.bias",
                    "down_proj.bias",
            )):
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            else:
                raise RuntimeError(f"Unhandled {name} module weights")

        del model_params

    print(
        f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
    )
    return weights


def convert(worker_rank, args, convert_kwargs):
    for rank in range(worker_rank, args.world_size):
        weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
        safetensors.torch.save_file(weights,
                                    args.output_dir / f"rank{rank}.safetensors")


def main():
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    print(f"Loading source parameters from {args.model_dir.absolute()}")
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()
    ckpt_params = ckpt_parser.load_parameters(args.model_dir)
    input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
    ckpt_params_dtype = str(input_embedding_weights.dtype).split(".")[-1]
    ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params)
    print(f"Source configuration determined from parameters: {ckpt_config}")

    quant_config = tensorrt_llm.models.modeling_utils.QuantConfig()

    trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
        architecture="RecurrentGemmaForCausalLM",
        dtype=args.dtype or ckpt_params_dtype,
        logits_dtype="float32",
        vocab_size=ckpt_config["vocab_size"],
        max_position_embeddings=8192,
        hidden_size=ckpt_config["hidden_size"],
        num_hidden_layers=ckpt_config["num_hidden_layers"],
        num_attention_heads=ckpt_config["num_attention_heads"],
        num_key_value_heads=1,
        head_size=ckpt_config["hidden_size"] //
        ckpt_config["num_attention_heads"],
        hidden_act="gelu",
        intermediate_size=ckpt_config["intermediate_size"],
        norm_epsilon=1e-6,
        position_embedding_type="rope_gpt_neox",
        world_size=args.world_size,
        tp_size=args.world_size,
        pp_size=1,
        gpus_per_node=8,
        quantization=quant_config,
        conv_kernel=4,
        rotary_percentage=0.5,
        layer_types=ckpt_config["block_types"],
        rnn_hidden_size=ckpt_config["lru_width"],
        logits_soft_cap=ckpt_config["logits_soft_cap"],
        emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"],
    )

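    # max_position_embeddings, conv_kernel and rotary_percentage above are
    # hard-coded rather than read from the source checkpoint; adjust them if
    # the model variant being converted differs.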
    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    config_path = args.output_dir / "config.json"
    config_path.parent.mkdir(exist_ok=True, parents=True)
    LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
    with config_path.open("w") as config_file:
        json.dump(trt_llm_config_dict, config_file, indent=4)

    convert_args = dict(trt_llm_config=trt_llm_config,
                        model_dir=args.model_dir,
                        ckpt_parser=ckpt_parser)
    convert(0, args, convert_args)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()
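
# Example invocation (paths are illustrative; all flags are defined in
# parse_arguments above):
#   python convert_checkpoint.py --ckpt_type hf \
#       --model_dir ./recurrentgemma-2b \
#       --dtype bfloat16 \
#       --output_dir ./recurrentgemma_tllm_checkpoint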