TensorRT-LLMs/examples/mamba/convert_checkpoint.py
Kaiyu Xie 31ac30e928
Update TensorRT-LLM (#2215)
* Update TensorRT-LLM

---------

Co-authored-by: Sherlock Xu <65327072+Sherlock113@users.noreply.github.com>
2024-09-10 18:21:22 +08:00

479 lines
18 KiB
Python

import argparse
import copy
import json
import os
import re
import time
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Union
import safetensors.torch
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.models.convert_utils import (iterate_shard_files,
load_state_dict)
class CheckpointType(str, Enum):
    """On-disk checkpoint layouts this converter knows how to read.

    Selected on the command line via ``--ckpt_type``. Because of the ``str``
    mix-in, every member also compares equal to its plain string value.
    """
    mistral_inference = "mistral_inference"
    state_spaces = "state_spaces"
    hf = "hf"

    def __str__(self):
        # argparse renders `choices` with str(); return the bare value so the
        # usage/help text reads `{mistral_inference,state_spaces,hf}` instead
        # of `{CheckpointType.mistral_inference,...}`.
        return self.value
def parse_arguments():
    """Define the converter's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ckpt_type, model_dir, world_size, dtype,
        output_dir and log_level.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--ckpt_type",
        type=CheckpointType,
        choices=list(CheckpointType),
        default=CheckpointType.hf,
        help='Checkpoint type')
    add('--model_dir', type=Path, default=None)
    add("--world_size",
        type=int,
        default=1,
        help="world size, only support tensor parallelism now")
    add('--dtype',
        type=str,
        default='float16',
        choices=['float32', 'bfloat16', 'float16'])
    add('--output_dir',
        type=Path,
        default='mamba_tllm_checkpoint',
        help='The path to save the mamba TensorRT-LLM checkpoint')
    add('--log_level', type=str, default='info')

    return parser.parse_args()
def get_weight(config, prefix, dtype):
    """Return ``config['<prefix>.weight']`` cast to ``dtype`` and detached."""
    tensor = config[f'{prefix}.weight']
    return tensor.to(dtype).detach()
def get_bias(config, prefix, dtype):
    """Return ``config['<prefix>.bias']`` cast to ``dtype`` and detached.

    Returns None when ``config`` has no such key.
    """
    key = f'{prefix}.bias'
    if key not in config:
        return None
    return config[key].to(dtype).detach()
def get_weight_and_bias(config, prefix, dtype_w, dtype_b):
    """Fetch the ``(weight, bias)`` pair stored under ``prefix``.

    The weight is cast to ``dtype_w`` and the bias to ``dtype_b``. The bias
    is None when ``config`` has no ``<prefix>.bias`` entry.
    """
    weight = get_weight(config, prefix, dtype_w)
    bias = get_bias(config, prefix, dtype_b)
    return weight, bias
def get_tllm_linear_weight(weight, prefix, bias=None):
    """Pack a linear layer's tensors into a ``{name: tensor}`` dict.

    ``prefix`` is prepended verbatim (so it should end with a separator such
    as '.'). The weight is made contiguous; the bias, if given, is stored
    unchanged.
    """
    results = {f'{prefix}weight': weight.contiguous()}
    if bias is not None:
        results[f'{prefix}bias'] = bias
    return results
def split(v, tp_size, idx, dim=0):
    """Return the ``idx``-th of ``tp_size`` equal slices of ``v`` along ``dim``.

    The size of ``dim`` must be divisible by ``tp_size`` (checked with an
    assertion). With ``tp_size == 1`` the input tensor itself is returned.
    """
    dim_len = v.shape[dim]
    assert dim_len % tp_size == 0
    if tp_size == 1:
        return v
    chunk = dim_len // tp_size
    return torch.split(v, chunk, dim=dim)[idx]
def convert_hf_mamba(hf_mamba,
                     rank=0,
                     dtype='float32',
                     mamba_version: str = 'Mamba1',
                     mamba_config: dict = None):
    """Convert an in-memory HF Mamba1 model to TensorRT-LLM weight names.

    Args:
        hf_mamba: Loaded causal-LM model; only ``named_parameters()`` and
            ``config.num_hidden_layers`` are read.
        rank: Unused; kept so both converters share a calling convention.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``'float16'``) used
            for the converted weights. The ``dt_proj`` bias, ``A`` and ``D``
            are stored as float32 regardless.
        mamba_version: Unused; this path handles the Mamba1 layout, where
            ``in_proj`` stacks the x and z projections along dim 0.
        mamba_config: Unused; accepted (and ignored) so callers can pass the
            same keyword arguments they pass to ``convert_from_hf_checkpoint``.

    Returns:
        dict mapping TensorRT-LLM parameter names to tensors.
    """
    weights = {}
    tik = time.time()

    model_params = dict(hf_mamba.named_parameters())
    dtype = getattr(torch, dtype)

    for layer_idx in range(hf_mamba.config.num_hidden_layers):
        # ssm (mixer) layer
        prefix = f'backbone.layers.{layer_idx}.mixer.'
        tllm_prex = f'backbone.layers.{layer_idx}.ssm.'
        for layer in ['conv1d', 'x_proj', 'dt_proj', 'out_proj']:
            # The dt_proj bias is kept in fp32 and renamed to `dt_bias`.
            dtype_b = torch.float32 if layer == 'dt_proj' else dtype
            weight, bias = get_weight_and_bias(model_params, prefix + layer,
                                               dtype, dtype_b)
            if layer == 'conv1d':
                weight = weight.unsqueeze(3)
            weights[tllm_prex + layer + '.weight'] = weight
            if bias is not None:
                bias_name = ('dt_bias'
                             if layer == 'dt_proj' else layer + '.bias')
                weights[tllm_prex + bias_name] = bias

        # in_proj stacks x and z along dim 0; TRT-LLM stores them separately.
        weight, bias = get_weight_and_bias(model_params, prefix + 'in_proj',
                                           dtype, dtype)
        in_proj_weights = torch.split(weight, weight.size(0) // 2, dim=0)
        weights[tllm_prex + 'in_proj_x.weight'] = in_proj_weights[0]
        weights[tllm_prex + 'in_proj_z.weight'] = in_proj_weights[1]
        if bias is not None:
            in_proj_biases = torch.split(bias, bias.size(0) // 2, dim=0)
            weights[tllm_prex + 'in_proj_x.bias'] = in_proj_biases[0]
            # The second half is the z bias; it must not overwrite x's.
            weights[tllm_prex + 'in_proj_z.bias'] = in_proj_biases[1]

        # A = -exp(A_log), transposed; D as-is. Both kept in fp32.
        a_param = model_params[prefix + 'A_log'].float().detach()
        a_param = a_param.permute(1, 0).contiguous()
        weights[tllm_prex + 'A'] = -torch.exp(a_param)
        weights[tllm_prex + 'D'] = model_params[prefix + 'D'].float().detach()

        # per-block norm
        norm_prefix = f'backbone.layers.{layer_idx}.norm'
        norm_tllm = f'backbone.layers.{layer_idx}.input_layernorm.'
        weight, bias = get_weight_and_bias(model_params, norm_prefix, dtype,
                                           dtype)
        weights[norm_tllm + 'weight'] = weight
        if bias is not None:
            weights[norm_tllm + 'bias'] = bias

    # token embedding and final norm
    for name in ['backbone.embeddings', 'backbone.norm_f']:
        weight, bias = get_weight_and_bias(model_params, name, dtype, dtype)
        tllm_name = name.replace('embeddings',
                                 'vocab_embedding').replace('norm_f', 'ln_f')
        weights[tllm_name + '.weight'] = weight
        if bias is not None:
            weights[tllm_name + '.bias'] = bias

    # lm_head: named_parameters() drops a tied lm_head (it duplicates the
    # embedding), so an 'lm_head.weight' entry means the head is untied and
    # must be used as-is. Clone in either case so the entry never shares
    # storage with the embedding; safetensors refuses to save shared tensors.
    if 'lm_head.weight' in model_params:
        lm_head = get_weight(model_params, 'lm_head', dtype)
    else:
        lm_head = weights['backbone.vocab_embedding.weight']
    weights['lm_head.weight'] = lm_head.clone()

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights
def rename_hf_to_tllm(name: str):
    """Map an HF/state-spaces parameter name to its TensorRT-LLM name.

    Applies, in order: drop every ``model.`` segment; rename the embedding
    table to ``vocab_embedding``; rename ``mixer.`` to ``ssm.``, a per-layer
    ``norm.`` to ``input_layernorm.`` and ``norm_f.`` to ``ln_f.``; finally
    rename ``A_log`` to ``A`` or ``dt_proj.bias`` to ``dt_bias``.
    """
    # str.replace is a no-op when the substring is absent.
    name = name.replace('model.', '')

    # Embedding table (two spellings seen in the wild).
    if 'embeddings.' in name:
        name = name.replace('embeddings', 'vocab_embedding')
    elif 'embedding.' in name:
        name = name.replace('embedding', 'vocab_embedding')

    # Block-level module names; at most one of these applies.
    if 'mixer.' in name:
        name = name.replace('mixer.', 'ssm.')
    elif re.search(r'\d\.norm\.', name) is not None:
        name = name.replace('norm.', 'input_layernorm.')
    elif 'norm_f.' in name:
        name = name.replace('norm_f.', 'ln_f.')

    # SSM parameters whose names differ between the two conventions.
    if 'A_log' in name:
        return name.replace('A_log', 'A')
    if 'dt_proj.bias' in name:
        return name.replace('dt_proj.bias', 'dt_bias')
    return name
def convert_from_hf_checkpoint(mamba_config: dict,
                               model_dir: Union[str, Path],
                               rank=0,
                               dtype: Union[str, torch.dtype] = torch.float32,
                               mamba_version: str = 'Mamba1'):
    """Read Mamba weights from checkpoint files and convert them for TRT-LLM.

    Every tensor is renamed with ``rename_hf_to_tllm`` and reshaped where the
    TRT-LLM layout differs. For Mamba2, the tensors are also sliced for the
    tensor-parallel ``rank``.

    Args:
        mamba_config: TensorRT-LLM config dict. Reads ``mapping.tp_size``,
            ``rnn_hidden_size`` and ``state_size``; for Mamba2 also
            ``rnn_head_size`` and ``ngroups``.
        model_dir: Checkpoint location, iterated with ``iterate_shard_files``.
        rank: Tensor-parallel rank whose slice of the Mamba2 weights is kept.
        dtype: torch dtype, or a string converted with
            ``tensorrt_llm.str_dtype_to_torch``; passed to ``load_state_dict``.
            ``A``, ``D`` and the dt bias are upcast with ``.float()`` afterwards.
        mamba_version: ``'Mamba1'`` or ``'Mamba2'``; selects the layout
            handling below.

    Returns:
        dict mapping TensorRT-LLM parameter names to CPU tensors.
    """
    logger.info('Loading weights from HF Mamba...')
    tik = time.time()

    tp_rank = rank
    tp_size = mamba_config['mapping']['tp_size']
    d_inner = mamba_config['rnn_hidden_size']
    d_state = mamba_config['state_size']
    weights = {}
    if isinstance(dtype, str):
        dtype = tensorrt_llm.str_dtype_to_torch(dtype)

    # Convert one checkpoint file at a time; each file's dict is released
    # with `del` once its tensors have been copied into `weights`.
    for model_file in iterate_shard_files(model_dir, 0):
        logger.debug(f'Loading file {str(model_file)}...')
        model_params = load_state_dict(model_file, dtype=dtype)
        for name, param in model_params.items():
            logger.debug(f'Converting weight {name}...')
            tllm_name = rename_hf_to_tllm(name)
            param = param.detach().cpu()

            # Per-parameter value/shape fixes (names are the ORIGINAL names).
            if 'A_log' in name:
                # TRT-LLM stores A itself: A = -exp(A_log), in fp32.
                param = -torch.exp(param.float())
                if mamba_version == 'Mamba1':
                    # Mamba1 additionally expects A transposed.
                    param = param.permute(1, 0).contiguous()
            elif 'D' in name:
                # NOTE(review): matches any name containing an uppercase 'D'
                # (intended for `mixer.D`).
                param = param.float()
            elif 'dt_proj.bias' in name:
                param = param.float()
            elif 'dt_bias' in name:
                param = param.float()
            elif 'conv1d.weight' in name:
                # Append a trailing singleton dimension to the conv kernel.
                param = param.unsqueeze(3)

            # split in_proj in Mamba1
            if 'in_proj' in name and mamba_version == 'Mamba1':
                # Two equal halves along dim 0 -> in_proj_x and in_proj_z.
                in_proj_params = torch.split(param, param.size(0) // 2, dim=0)
                weights[tllm_name.replace('proj', 'proj_x')] = in_proj_params[0]
                weights[tllm_name.replace('proj', 'proj_z')] = in_proj_params[1]
            elif 'in_proj' in name and mamba_version == 'Mamba2':
                # Mamba2 in_proj is the dim-0 concatenation
                # [z (d_inner) | x (d_inner) | B | C (ngroups*d_state each) |
                # dt (nheads)]. Each part is sliced for this TP rank and the
                # slices are re-concatenated in the same order.
                nheads = d_inner // mamba_config['rnn_head_size']
                ngroups = mamba_config['ngroups']
                in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
                    param, [
                        d_inner, d_inner, ngroups * d_state, ngroups * d_state,
                        nheads
                    ],
                    dim=0)
                in_proj_z = split(in_proj_z, tp_size, tp_rank, dim=0)
                in_proj_x = split(in_proj_x, tp_size, tp_rank, dim=0)
                in_proj_b = split(in_proj_b, tp_size, tp_rank, dim=0)
                in_proj_c = split(in_proj_c, tp_size, tp_rank, dim=0)
                in_proj_dt = split(in_proj_dt, tp_size, tp_rank, dim=0)
                in_proj = torch.concat(
                    [in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt])
                weights[tllm_name] = in_proj.contiguous()
            elif 'conv1d' in name and mamba_version == 'Mamba2':
                # Mamba2 conv1d (weight and bias) is [x | B | C] along dim 0;
                # slice each part for this rank, then re-concatenate.
                ngroups = mamba_config['ngroups']
                conv_x, conv_b, conv_c = torch.split(
                    param, [d_inner, ngroups * d_state, ngroups * d_state],
                    dim=0)
                conv_x = split(conv_x, tp_size, tp_rank, dim=0)
                conv_b = split(conv_b, tp_size, tp_rank, dim=0)
                conv_c = split(conv_c, tp_size, tp_rank, dim=0)
                conv = torch.concat([conv_x, conv_b, conv_c])
                weights[tllm_name] = conv.contiguous()
            elif any(keyword in name for keyword in (
                    'mixer.norm.weight',
                    'A_log',
                    'D',
                    'dt_proj.bias',
                    'dt_bias',
            )) and mamba_version == 'Mamba2':
                # Remaining Mamba2 SSM tensors are sliced along dim 0.
                weights[tllm_name] = split(param, tp_size, tp_rank, dim=0)
            elif 'out_proj' in name and mamba_version == 'Mamba2':
                # out_proj is sliced along dim 1.
                weights[tllm_name] = split(param, tp_size, tp_rank,
                                           dim=1).contiguous()
            else:
                # Everything else (and every Mamba1 tensor apart from
                # in_proj) is stored unsliced.
                weights[tllm_name] = param
        del model_params

    # lm_head
    # With no separate lm_head, or one that aliases the embedding's storage,
    # store a deep copy of the embedding so the two entries never share memory.
    emb = weights['backbone.vocab_embedding.weight']
    if 'lm_head.weight' not in weights or weights['lm_head.weight'].data_ptr(
    ) == emb.data_ptr():
        weights['lm_head.weight'] = copy.deepcopy(emb)
    if mamba_version == 'Mamba2':
        # Mamba2: keep only this rank's dim-0 slice of lm_head.
        weights['lm_head.weight'] = split(weights['lm_head.weight'],
                                          tp_size,
                                          tp_rank,
                                          dim=0)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
    return weights
def do_convert_from_ckpt(args):
    """Return True when ``args.model_dir`` names an existing local path.

    Used to choose between reading checkpoint files directly and loading a
    model object. ``model_dir`` may be a ``str`` or ``Path``. When it is
    ``None`` (``--model_dir`` not given), this returns False instead of
    raising ``AttributeError``.
    """
    model_dir = args.model_dir
    return model_dir is not None and Path(model_dir).exists()
def convert(worker_rank, args, convert_args):
    """Convert and save weights for ranks ``worker_rank`` .. ``world_size - 1``.

    Each rank's tensors are written to ``<output_dir>/rank<N>.safetensors``.
    The converter is chosen once via ``do_convert_from_ckpt(args)`` and
    called with ``rank=`` plus ``convert_args``.
    """
    converter = (convert_from_hf_checkpoint
                 if do_convert_from_ckpt(args) else convert_hf_mamba)
    for rank in range(worker_rank, args.world_size):
        rank_weights = converter(rank=rank, **convert_args)
        out_path = args.output_dir / f'rank{rank}.safetensors'
        safetensors.torch.save_file(rank_weights, out_path)
@dataclass
class MambaConfig:
    """Hyper-parameters for checkpoints that do not ship an HF config.

    Attribute names mirror the ones this script reads from HF configs
    (``hidden_size``, ``state_size``, ``n_groups``, ...), so the rest of the
    conversion can use either object. Keys the source config omits keep the
    defaults below.
    """
    # Model-level settings.
    architectures: List[str] = field(
        default_factory=lambda: ['MambaForCausalLM'])
    d_intermediate: int = 0
    vocab_size: int = 50277
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    pad_vocab_size_multiple: int = 8
    hidden_size: int = 2560
    num_hidden_layers: int = 64
    intermediate_size: int = 0
    # SSM (mixer) settings.
    state_size: int = 128
    conv_kernel: int = 4
    use_bias: bool = False
    head_dim: int = 64
    n_groups: int = 1
    chunk_size: int = 256
    ssm_rmsnorm: bool = True

    def update(self, data_dict):
        """Set/overwrite instance attributes in place from ``data_dict``."""
        vars(self).update(data_dict)
def load_config_hf(model_name, ckpt_type):
    """Load the source model config and detect the Mamba version.

    Args:
        model_name: Model directory (or HF model id for ``CheckpointType.hf``).
        ckpt_type: A ``CheckpointType`` member selecting the config format:
            ``hf`` uses ``AutoConfig``; ``state_spaces`` reads ``config.json``;
            ``mistral_inference`` reads ``params.json``.

    Returns:
        ``(hf_config, mamba_version)``, where ``hf_config`` is either a
        transformers config or a ``MambaConfig`` and ``mamba_version`` is
        ``'Mamba1'`` or ``'Mamba2'`` (or the ``layer`` value from a
        state-spaces ``ssm_cfg``).

    Raises:
        ValueError: if ``ckpt_type`` is not a supported checkpoint type.
    """
    if ckpt_type == CheckpointType.hf:  # transformer compatible models
        hf_config = AutoConfig.from_pretrained(model_name,
                                               trust_remote_code=True)
        mamba_version = 'Mamba2' if hf_config.model_type == 'mamba2' else 'Mamba1'
    elif ckpt_type == CheckpointType.state_spaces:  # state-spaces/mamba models
        with open(os.path.join(model_name, 'config.json')) as f:
            config = json.load(f)
        ssm_cfg = config.pop('ssm_cfg', {})
        cfg_to_mamba_cfg = {
            'd_model': 'hidden_size',
            'n_layer': 'num_hidden_layers',
            'fused_add_norm': None,
            'tie_embeddings': None,
        }
        ssm_cfg_to_mamba_cfg = {
            'd_state': 'state_size',
            'd_conv': 'conv_kernel',
            'bias': 'use_bias',
            'headdim': 'head_dim',
            'ngroups': 'n_groups',
            'chunk_size': 'chunk_size',
            'rmsnorm': 'ssm_rmsnorm',
        }
        # Rename (or drop, when mapped to None) top-level keys.
        for k in cfg_to_mamba_cfg:
            if k in config:
                v = config.pop(k)
                if cfg_to_mamba_cfg[k] is not None:
                    config[cfg_to_mamba_cfg[k]] = v
        # Lift the SSM-specific keys into the flat config.
        for k in ssm_cfg_to_mamba_cfg:
            if k in ssm_cfg and ssm_cfg_to_mamba_cfg[k] is not None:
                config[ssm_cfg_to_mamba_cfg[k]] = ssm_cfg[k]
        hf_config = MambaConfig(**config)
        hf_config.intermediate_size = ssm_cfg.get('expand',
                                                  2) * hf_config.hidden_size
        mamba_version = ssm_cfg.pop("layer", "Mamba1")
        if mamba_version == 'Mamba1' and 'd_state' not in ssm_cfg:
            # MambaConfig's state_size default (128) is the Mamba2 default;
            # mamba_ssm's Mamba1 block defaults d_state to 16.
            hf_config.state_size = 16
    elif ckpt_type == CheckpointType.mistral_inference:  # mistral inference format
        with open(os.path.join(model_name, 'params.json')) as f:
            config = json.load(f)
        # `expand` is not a MambaConfig field; take it out before
        # constructing the dataclass, otherwise MambaConfig(**config) raises.
        expand = config.pop('expand', 2)
        cfg_to_mamba_cfg = {
            'dim': 'hidden_size',
            'n_layers': 'num_hidden_layers',
            'n_groups': 'n_groups',
            'fused_add_norm': None,
            'tie_embeddings': None,
            'model_type': None,
        }
        for k in cfg_to_mamba_cfg:
            if k in config:
                v = config.pop(k)
                if cfg_to_mamba_cfg[k] is not None:
                    config[cfg_to_mamba_cfg[k]] = v
        hf_config = MambaConfig(**config)
        hf_config.intermediate_size = expand * hf_config.hidden_size
        mamba_version = 'Mamba2'
    else:
        raise ValueError(f'Unsupported checkpoint type: {ckpt_type}')
    return hf_config, mamba_version
def main():
    """Command-line entry point: write ``config.json`` and per-rank weights.

    Loads the source config, builds the TensorRT-LLM config dict, validates
    the requested conversion, writes ``config.json`` into ``--output_dir`` and
    then writes ``rank<N>.safetensors`` for every rank.

    Raises:
        ValueError: for Mamba2 without a local checkpoint directory, or for
            Mamba1 with ``--world_size`` other than 1.
    """
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    logger.set_level(args.log_level)

    tik = time.time()
    args.output_dir.mkdir(exist_ok=True, parents=True)
    hf_config, mamba_version = load_config_hf(args.model_dir, args.ckpt_type)

    # Validate before writing config.json. Explicit exceptions, not `assert`,
    # so the checks survive `python -O`.
    convert_from_ckpt = do_convert_from_ckpt(args)
    # TODO: Add convert_hf_mamba support for Mamba2 when transformers can support Mamba2 models
    if not convert_from_ckpt and mamba_version == 'Mamba2':
        raise ValueError("Mamba2 can only support convert from checkpoints.")
    if args.world_size != 1 and mamba_version != 'Mamba2':
        raise ValueError("Mamba1 can not support tensor parallelism.")

    # Round the vocabulary up to a multiple of pad_vocab_size_multiple.
    vocab_size = hf_config.vocab_size
    pad_vocab_size_multiple = getattr(hf_config, "pad_vocab_size_multiple", 1)
    if vocab_size % pad_vocab_size_multiple != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size %
                                                 pad_vocab_size_multiple)

    config = {
        'architecture': 'MambaForCausalLM',
        'dtype': args.dtype,
        'logits_dtype': 'float32',
        'hidden_size': hf_config.hidden_size,
        'num_hidden_layers': hf_config.num_hidden_layers,
        'layer_types': ['recurrent'],
        'vocab_size': vocab_size,
        'rms_norm': hf_config.rms_norm,
        'residual_in_fp32': hf_config.residual_in_fp32,
        'pad_vocab_size_multiple': pad_vocab_size_multiple,
        'hidden_act': 'silu',
        'num_attention_heads': args.world_size,
        'rnn_hidden_size': hf_config.intermediate_size,
        'rnn_conv_dim_size': hf_config.intermediate_size,
        'state_size': hf_config.state_size,
        'conv_kernel': hf_config.conv_kernel,
        'use_bias': hf_config.use_bias,
        'mamba_version': mamba_version,
        'mapping': {
            'world_size': args.world_size,
            'tp_size': args.world_size,
            'pp_size': 1
        },
    }
    if mamba_version == 'Mamba2':
        # The Mamba2 conv mixes x, B and C, hence the wider conv dimension.
        conv_dim = hf_config.intermediate_size + 2 * hf_config.n_groups * hf_config.state_size
        ssm_rmsnorm = getattr(hf_config, "ssm_rmsnorm", hf_config.rms_norm)
        mamba2_cfg = {
            'rnn_head_size': hf_config.head_dim,
            'rnn_conv_dim_size': conv_dim,
            'ngroups': hf_config.n_groups,
            'chunk_size': hf_config.chunk_size,
            'ssm_rmsnorm': ssm_rmsnorm,
        }
        config.update(mamba2_cfg)

    with (args.output_dir / 'config.json').open('w') as f:
        json.dump(config, f, indent=4)

    convert_args = dict(dtype=args.dtype, mamba_version=mamba_version)
    if convert_from_ckpt:
        logger.info('Convert by using checkpoint')
        convert_args['model_dir'] = args.model_dir
        # Only the checkpoint converter reads the TRT-LLM config (TP mapping,
        # SSM sizes); the model-object converter does not take it.
        convert_args['mamba_config'] = config
    else:
        logger.info('Convert by using model')
        convert_args['hf_mamba'] = AutoModelForCausalLM.from_pretrained(
            args.model_dir,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True)

    convert(0, args, convert_args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()