TensorRT-LLM/examples/mamba/convert_checkpoint.py

import argparse
import copy
import json
import os
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union

import safetensors.torch
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils import CONFIG_NAME
from transformers.utils.hub import cached_file

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.models.convert_utils import (iterate_shard_files,
                                               load_state_dict)
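
# Converts a Mamba checkpoint (HF transformers, state-spaces, or Mistral format)
# into a TensorRT-LLM checkpoint: a config.json plus rank<i>.safetensors weight
# files under --output_dir. Example invocation (paths are illustrative):
#   python convert_checkpoint.py --model_dir ./mamba-hf-model \
#       --dtype bfloat16 --output_dir ./mamba_tllm_checkpoint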


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=Path, default=None)
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument(
        '--output_dir',
        type=Path,
        default='mamba_tllm_checkpoint',
        help='The path to save the mamba TensorRT-LLM checkpoint')
    parser.add_argument('--log_level', type=str, default='info')
    args = parser.parse_args()
    return args


def get_weight(config, prefix, dtype):
    return config[prefix + '.weight'].to(dtype).detach()


def get_bias(config, prefix, dtype):
    if (prefix + '.bias') in config:
        return config[prefix + '.bias'].to(dtype).detach()
    return None


def get_weight_and_bias(config, prefix, dtype_w, dtype_b):
    return get_weight(config, prefix,
                      dtype_w), get_bias(config, prefix, dtype_b)


def get_tllm_linear_weight(weight, prefix, bias=None):
    results = {}
    results[prefix + 'weight'] = weight.contiguous()
    if bias is not None:
        results[prefix + 'bias'] = bias
    return results


def convert_hf_mamba(hf_mamba,
                     rank=0,
                     dtype='float32',
                     mamba_version: str = 'Mamba1'):
    weights = {}
    tik = time.time()

    model_params = dict(hf_mamba.named_parameters())
    dtype = getattr(torch, dtype)

    # Parameter names in mamba block
    for l in range(hf_mamba.config.num_hidden_layers):
        # ssm layer
        prefix = f'backbone.layers.{l}.mixer.'
        tllm_prex = f'backbone.layers.{l}.ssm.'
        for layer in ['conv1d', 'x_proj', 'dt_proj', 'out_proj']:
            dtype_b = torch.float32 if layer == 'dt_proj' else dtype
            weight, bias = get_weight_and_bias(model_params, prefix + layer,
                                               dtype, dtype_b)
            if layer == 'conv1d':
                weight = weight.unsqueeze(3)
            tllm_weight_name = tllm_prex + layer + '.weight'
            tllm_bias_name = tllm_prex + ('dt_bias' if layer == 'dt_proj' else
                                          layer + '.bias')
            weights[tllm_weight_name] = weight
            if bias is not None:
                weights[tllm_bias_name] = bias

        # in_proj
        weight, bias = get_weight_and_bias(model_params, prefix + 'in_proj',
                                           dtype, dtype)
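        # HF packs the x and z branches into a single in_proj; TRT-LLM stores
        # them as separate in_proj_x / in_proj_z tensors.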
        in_proj_weights = torch.split(weight, weight.size(0) // 2, dim=0)
        tllm_weight_name = tllm_prex + 'in_proj.weight'
        weights[tllm_weight_name.replace('proj', 'proj_x')] = in_proj_weights[0]
        weights[tllm_weight_name.replace('proj', 'proj_z')] = in_proj_weights[1]
        if bias is not None:
            in_proj_biases = torch.split(bias, bias.size(0) // 2, dim=0)
            tllm_bias_name = tllm_prex + 'in_proj.bias'
            weights[tllm_bias_name.replace('proj',
                                           'proj_x')] = in_proj_biases[0]
            weights[tllm_bias_name.replace('proj',
                                           'proj_z')] = in_proj_biases[1]

        # A and D
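        # HF stores A as A_log with shape (intermediate_size, state_size);
        # convert to A = -exp(A_log), transposed to (state_size, intermediate_size).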
        Aparam = model_params[prefix + 'A_log'].float().detach()
        Aparam = Aparam.permute(1, 0).contiguous()
        weights[tllm_prex + 'A'] = -torch.exp(Aparam)
        weights[tllm_prex + 'D'] = model_params[prefix + 'D'].float().detach()

        # norm
        prefix = f'backbone.layers.{l}.norm'
        tllm_prex = f'backbone.layers.{l}.input_layernorm.'
        weight, bias = get_weight_and_bias(model_params, prefix, dtype, dtype)
        weights[tllm_prex + 'weight'] = weight
        if bias is not None:
            weights[tllm_prex + 'bias'] = bias

    # others
    for layer in ['backbone.embeddings', 'backbone.norm_f']:
        weight, bias = get_weight_and_bias(model_params, layer, dtype, dtype)
        layer = layer.replace('embeddings', 'vocab_embedding')
        layer = layer.replace('norm_f', 'ln_f')
        weights[layer + '.weight'] = weight
        if bias is not None:
            weights[layer + '.bias'] = bias

    weights['lm_head.weight'], _ = get_weight_and_bias(model_params,
                                                       'backbone.embeddings',
                                                       dtype, dtype)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights


def rename_hf_to_tllm(name: str):
    """Rename an HF parameter name to the corresponding TRT-LLM style name."""
    # remove model
    if 'model.' in name:
        name = name.replace('model.', '')

    # change layer name
    if 'embeddings.' in name:
        name = name.replace('embeddings', 'vocab_embedding')
    elif 'embedding.' in name:
        name = name.replace('embedding', 'vocab_embedding')

    norm_pattern = r'\d\.norm\.'
    if 'mixer.' in name:
        name = name.replace('mixer.', 'ssm.')
    elif re.search(norm_pattern, name):
        name = name.replace('norm.', 'input_layernorm.')
    elif 'norm_f.' in name:
        name = name.replace('norm_f.', 'ln_f.')

    # Parameter names in ssm layers
    if 'A_log' in name:
        name = name.replace('A_log', 'A')
    elif 'dt_proj.bias' in name:
        name = name.replace('dt_proj.bias', 'dt_bias')
    return name


def convert_from_hf_checkpoint(model_dir: Union[str, Path],
                               rank=0,
                               dtype: Union[str, torch.dtype] = torch.float32,
                               mamba_version: str = 'Mamba1'):
    logger.info('Loading weights from HF Mamba...')
    tik = time.time()

    weights = {}
    if isinstance(dtype, str):
        dtype = tensorrt_llm.str_dtype_to_torch(dtype)

    for model_file in iterate_shard_files(model_dir, 0):
        logger.debug(f'Loading file {str(model_file)}...')
        model_params = load_state_dict(model_file, dtype=dtype)
        for name, param in model_params.items():
            logger.debug(f'Converting weight {name}...')
            tllm_name = rename_hf_to_tllm(name)
            param = param.detach().cpu()
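            # Keep the SSM scan parameters (A, D, dt bias) in float32 and give
            # the conv1d weight a trailing singleton dim, as TRT-LLM expects.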
            if 'A_log' in name:
                param = -torch.exp(param.float())
                if mamba_version == 'Mamba1':
                    param = param.permute(1, 0).contiguous()
            elif 'D' in name:
                param = param.float()
            elif 'dt_proj.bias' in name:
                param = param.float()
            elif 'dt_bias' in name:
                param = param.float()
            elif 'conv1d.weight' in name:
                param = param.unsqueeze(3)

            # split in_proj in Mamba1
            if 'in_proj' in name and mamba_version == 'Mamba1':
                in_proj_params = torch.split(param, param.size(0) // 2, dim=0)
                weights[tllm_name.replace('proj', 'proj_x')] = in_proj_params[0]
                weights[tllm_name.replace('proj', 'proj_z')] = in_proj_params[1]
            else:
                weights[tllm_name] = param
        del model_params

    # lm_head
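    # If lm_head is absent or tied to the embedding (same underlying storage),
    # duplicate the embedding weights so the checkpoint has an explicit lm_head.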
    emb = weights['backbone.vocab_embedding.weight']
    if 'lm_head.weight' not in weights or weights['lm_head.weight'].data_ptr(
    ) == emb.data_ptr():
        weights['lm_head.weight'] = copy.deepcopy(emb)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
    return weights


def do_convert_from_ckpt(args):
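    # Convert directly from checkpoint shards when model_dir is a local path;
    # otherwise the full HF model is loaded via AutoModelForCausalLM.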
    return args.model_dir.exists()


def convert(worker_rank, args, convert_args):
    convert_from_ckpt = do_convert_from_ckpt(args)
    world_size = 1
    args.workers = 1
    for rank in range(worker_rank, world_size, args.workers):
        if convert_from_ckpt:
            weights = convert_from_hf_checkpoint(rank=rank, **convert_args)
        else:
            weights = convert_hf_mamba(rank=rank, **convert_args)
        safetensors.torch.save_file(weights,
                                    args.output_dir / f'rank{rank}.safetensors')


@dataclass
class MambaConfig:
    architectures: List[str] = field(
        default_factory=lambda: ['MambaForCausalLM'])
    d_intermediate: int = 0
    vocab_size: int = 50277
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    pad_vocab_size_multiple: int = 8
    hidden_size: int = 2560
    num_hidden_layers: int = 64
    intermediate_size: int = 0
    state_size: int = 128
    conv_kernel: int = 4
    use_bias: bool = False
    headdim: int = 64
    ngroups: int = 1
    chunk_size: int = 256
    ssm_rmsnorm: bool = True

    def update(self, data_dict):
        self.__dict__.update(data_dict)


def load_config_hf(model_name):
    resolved_archive_file = cached_file(
        model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    if resolved_archive_file is None:
        resolved_archive_file = os.path.join(model_name, 'params.json')
    config = json.load(open(resolved_archive_file))
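
    # Three config flavors are handled: transformers-style config.json,
    # state-spaces configs carrying an 'ssm_cfg' entry, and Mistral-format
    # params.json.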
    if 'transformers_version' in config:  # transformer compatible models
        hf_config = AutoConfig.from_pretrained(model_name,
                                               trust_remote_code=True)
        # TODO: change mamba_version when transformers can support Mamba2 models
        mamba_version = 'Mamba1'
    elif 'ssm_cfg' in config:  # state-spaces/mamba models
        ssm_cfg = config.pop('ssm_cfg')
        cfg_to_mamba_cfg = {
            'd_model': 'hidden_size',
            'n_layer': 'num_hidden_layers',
            'fused_add_norm': None,
            'tie_embeddings': None,
        }
        ssm_cfg_to_mamba_cfg = {
            'd_state': 'state_size',
            'd_conv': 'conv_kernel',
            'bias': 'use_bias',
            'headdim': 'headdim',
            'ngroups': 'ngroups',
            'chunk_size': 'chunk_size',
            'rmsnorm': 'ssm_rmsnorm',
        }
        for k in cfg_to_mamba_cfg:
            if k in config:
                v = config.pop(k)
                if cfg_to_mamba_cfg[k] is not None:
                    config[cfg_to_mamba_cfg[k]] = v
        for k in ssm_cfg_to_mamba_cfg:
            if k in ssm_cfg and ssm_cfg_to_mamba_cfg[k] is not None:
                config[ssm_cfg_to_mamba_cfg[k]] = ssm_cfg[k]

        hf_config = MambaConfig(**config)
        if 'expand' in ssm_cfg:
            expand = ssm_cfg['expand']
            hf_config.intermediate_size = expand * hf_config.hidden_size
        else:
            hf_config.intermediate_size = 2 * hf_config.hidden_size
        mamba_version = ssm_cfg.pop("layer", "Mamba1")
    else:  # mistral format
        cfg_to_mamba_cfg = {
            'dim': 'hidden_size',
            'n_layers': 'num_hidden_layers',
            'n_groups': 'ngroups',
            'fused_add_norm': None,
            'tie_embeddings': None,
            'model_type': None,
        }
        for k in cfg_to_mamba_cfg:
            if k in config:
                v = config.pop(k)
                if cfg_to_mamba_cfg[k] is not None:
                    config[cfg_to_mamba_cfg[k]] = v
        hf_config = MambaConfig(**config)
        if 'expand' in config:
            expand = config['expand']
            hf_config.intermediate_size = expand * hf_config.hidden_size
        else:
            hf_config.intermediate_size = 2 * hf_config.hidden_size
        mamba_version = 'Mamba2'
    return hf_config, mamba_version


def main():
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    args.output_dir.mkdir(exist_ok=True, parents=True)

    hf_config, mamba_version = load_config_hf(args.model_dir)

    vocab_size = hf_config.vocab_size
    pad_vocab_size_multiple = hf_config.pad_vocab_size_multiple
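    # Pad vocab_size up to a multiple of pad_vocab_size_multiple.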
    if vocab_size % pad_vocab_size_multiple != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size %
                                                 pad_vocab_size_multiple)

    config = {
        'architecture': hf_config.architectures[0],
        'dtype': args.dtype,
        'logits_dtype': 'float32',
        'hidden_size': hf_config.hidden_size,
        'num_hidden_layers': hf_config.num_hidden_layers,
        'layer_types': ['recurrent'],
        'vocab_size': vocab_size,
        'rms_norm': hf_config.rms_norm,
        'residual_in_fp32': hf_config.residual_in_fp32,
        'pad_vocab_size_multiple': hf_config.pad_vocab_size_multiple,
        'hidden_act': 'silu',
        'num_attention_heads': 1,
        'rnn_hidden_size': hf_config.intermediate_size,
        'rnn_conv_dim_size': hf_config.intermediate_size,
        'state_size': hf_config.state_size,
        'conv_kernel': hf_config.conv_kernel,
        'use_bias': hf_config.use_bias,
        'mamba_version': mamba_version,
    }
    if mamba_version == 'Mamba2':
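        # In Mamba2 the convolution also covers the B and C heads, so the conv
        # dim is intermediate_size + 2 * ngroups * state_size.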
        conv_dim = hf_config.intermediate_size + 2 * hf_config.ngroups * hf_config.state_size
        mamba2_cfg = {
            'rnn_head_size': hf_config.headdim,
            'rnn_conv_dim_size': conv_dim,
            'ngroups': hf_config.ngroups,
            'chunk_size': hf_config.chunk_size,
            'ssm_rmsnorm': hf_config.ssm_rmsnorm,
        }
        config.update(mamba2_cfg)

    with (args.output_dir / 'config.json').open('w') as f:
        json.dump(config, f, indent=4)

    convert_from_ckpt = do_convert_from_ckpt(args)
    # TODO: Add convert_hf_mamba support for Mamba2 when transformers can support Mamba2 models
    assert convert_from_ckpt or mamba_version != 'Mamba2', "Mamba2 can only support convert from checkpoints."
    if not convert_from_ckpt:
        logger.info('Convert by using model')
        hf_mamba = AutoModelForCausalLM.from_pretrained(args.model_dir,
                                                        device_map="auto",
                                                        torch_dtype="auto",
                                                        trust_remote_code=True)
    else:
        logger.info('Convert by using checkpoint')
        hf_mamba = None

    convert_args = dict(dtype=args.dtype, )
    if convert_from_ckpt:
        convert_args['model_dir'] = args.model_dir
    else:
        convert_args['hf_mamba'] = hf_mamba
    convert_args['mamba_version'] = mamba_version
    convert(0, args, convert_args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()