TensorRT-LLM/examples/gemma/convert_checkpoint.py

#!/usr/bin/env python3
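"""Convert a Gemma checkpoint (JAX, Keras, or PyTorch) into a TensorRT-LLM checkpoint.

The script reads the source weights, renames and reshapes them into the
TensorRT-LLM layout, optionally applies weight-only / SmoothQuant / FP8
quantization, and writes one ``rank<i>.safetensors`` file per tensor-parallel
rank plus a ``config.json`` describing the model.

Illustrative invocation (paths are placeholders):

    python3 convert_checkpoint.py --ckpt-type jax \\
        --model-dir /path/to/gemma_ckpt \\
        --output-model-dir /path/to/trt_llm_ckpt \\
        --dtype bfloat16 --world-size 1
"""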
import argparse
import json
import logging
import math
import pathlib
import re
import time
import typing
import flax.traverse_util
import h5py
import numpy as np
import safetensors.numpy
import safetensors.torch
import sentencepiece as sp
import torch
import utils.params
import utils.transformer
from datasets import load_dataset
from easydict import EasyDict
import tensorrt_llm
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.models.gemma.smoothquant import *
from tensorrt_llm.models.gemma.weight import (dummy_weights_awq,
load_from_fp8_llama,
quantize_fp8_weights)
LOGGER = logging.getLogger("convert_checkpoint")
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt-type",
type=str,
choices=["jax", "keras", "torch"])
parser.add_argument("--model-dir", type=pathlib.Path, required=True)
parser.add_argument("--output-model-dir", type=pathlib.Path, required=True)
parser.add_argument("--world-size",
type=int,
default=1,
help="world size, only support tensor parallelism now")
parser.add_argument(
"--use-weight-only-with-precision",
choices=["int8", "int4", "w4a8_awq", "w4a16_awq"],
help=
"help='Quantize weights for the various GEMMs to INT4/INT8. Define the precision for the weights.",
)
parser.add_argument("--dtype",
type=str,
choices=["float32", "bfloat16", "float16"])
parser.add_argument(
"--enable_fp8",
action="store_true",
help="Use FP8 Linear layer for Attention QKV/Dense and MLP.")
parser.add_argument(
"--fp8_kv_cache",
action="store_true",
help=
"By default, we use dtype for KV cache. fp8_kv_cache chooses int8 quantization for KV",
)
parser.add_argument(
"--ammo_quant_ckpt_path",
default=None,
help=
"Path of a directory to quantized model checkpoints in .safetensors format or \
path of a quantized model checkpoint in .npz format")
parser.add_argument('--use_smooth_quant',
default=False,
action="store_true",
help="Use smooth quant.")
parser.add_argument(
"--calibrate_kv_cache",
"-kv",
action="store_true",
help=
"Generate scaling factors for KV cache. Used for storing KV cache in int8."
)
parser.add_argument(
'--per_channel',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
"--use_smooth_quant_plugin",
"-sq",
type=float,
default=None,
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
" to Smoothquant the model, and output int8 weights."
" A good first try is 0.5. Must be in [0, 1]")
parser.add_argument(
'--tokenizer_dir',
default=None,
help='tokenizer path; defaults to jax_model_dir if left unspecified')
args = parser.parse_args()
return args
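# Checkpoint parsers. Each parser exposes the same duck-typed interface used by
# convert_from_checkpoint() and main():
#   load_parameters(path)   -> raw parameter container for the checkpoint
#   embedding_weights(p)    -> input-embedding matrix (used to infer vocab size and dtype)
#   get_config(path, p, n)  -> model hyper-parameters (layers, heads, dims, ...)
#   rename_to_trt_llm(name) -> TRT-LLM parameter name, or None if the tensor is
#                              merged into another parameter and handled there
#   flatten_params(p)       -> flat {name: numpy array} mapping
# CKPT_PARSER (defined below) maps the --ckpt-type argument to one of these classes.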
class JAXParser:
def load_parameters(self, checkpoint_path: pathlib.Path):
checkpoint_path = checkpoint_path.absolute()
return utils.params.nest_params(
utils.params.param_remapper(
utils.params.load_params(checkpoint_path)))
def embedding_weights(self, ckpt_params):
return ckpt_params["transformer"]["embedder"]["input_embedding"]
def get_config(self, checkpoint_path, ckpt_params, num_embed):
return utils.transformer.TransformerConfig.from_params(
ckpt_params, num_embed=num_embed)
def rename_to_trt_llm(self, name: str):
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix, name = name.split(".", maxsplit=1)
assert prefix == "transformer"
sub_patterns = (
(r"embedder.input_embedding", r"vocab_embedding.weight"),
(r"layer_(\d+).pre_attention_norm.scale",
r"layers.\1.input_layernorm.weight"),
(r"layer_(\d+).attn.q_einsum.w", r"layers.\1.attention.qkv.weight"),
(r"layer_(\d+).attn.kv_einsum.w",
None), # drop as kv will be concatenated with q
(r"layer_(\d+).attn.qkv_einsum.w",
r"layers.\1.attention.qkv.weight"),
(r"layer_(\d+).attn.attn_vec_einsum.w",
r"layers.\1.attention.dense.weight"),
(r"layer_(\d+).mlp.gating_einsum", r"layers.\1.mlp.fc.weight"),
(r"layer_(\d+).mlp.linear", r"layers.\1.mlp.proj.weight"),
(r"layer_(\d+).pre_ffw_norm.scale",
r"layers.\1.post_layernorm.weight"),
(r"final_norm.scale", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params):
return flax.traverse_util.flatten_dict(params, sep=".")
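# Parser for Keras Gemma checkpoints: config.json points to an HDF5 weights file
# whose datasets are addressed by layer paths such as layers/gemma_decoder_block_<i>/...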
class KerasParser:
def load_parameters(self, checkpoint_path: pathlib.Path):
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
weights_file = json.load(open(checkpoint_path / config_file))["weights"]
h5_path = checkpoint_path / weights_file
return h5py.File(h5_path, "r+")
def embedding_weights(self, ckpt_params):
return np.array(ckpt_params["layers/reversible_embedding/vars/0"])
def get_config(self, checkpoint_path, ckpt_params, num_embed):
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
config_old = json.load(open(checkpoint_path / config_file))["config"]
config_new = {}
config_new["num_layers"] = config_old["num_layers"]
config_new["num_embed"] = config_old["vocabulary_size"]
config_new["embed_dim"] = config_old["hidden_dim"]
config_new["hidden_dim"] = config_old["intermediate_dim"] // 2
config_new["num_heads"] = config_old["num_query_heads"]
config_new["head_dim"] = config_old["head_dim"]
config_new["num_kv_heads"] = config_old["num_key_value_heads"]
return EasyDict(config_new)
def rename_to_trt_llm(self, name: str):
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
name = name.replace("/gemma_decoder_block/", "/gemma_decoder_block_0/")
sub_patterns = (
(r"layers/reversible_embedding/vars/0", r"vocab_embedding.weight"),
(r"layers/gemma_decoder_block_(\d+)/pre_attention_norm/vars/0",
r"layers.\1.input_layernorm.weight"),
(r"layers/gemma_decoder_block_(\d+)/attention/query_dense/vars/0",
r"layers.\1.attention.qkv.weight"),
(r"layers/gemma_decoder_block_(\d+)/attention/key_dense/vars/0",
None), # drop as k will be concatenated with q
(r"layers/gemma_decoder_block_(\d+)/attention/value_dense/vars/0",
None), # drop as v will be concatenated with q
(r"layers/gemma_decoder_block_(\d+)/attention/output_dense/vars/0",
r"layers.\1.attention.dense.weight"),
(r"layers/gemma_decoder_block_(\d+)/gating_ffw/vars/0",
r"layers.\1.mlp.fc.weight"),
(r"layers/gemma_decoder_block_(\d+)/gating_ffw_2/vars/0",
None), # merged with above
(r"layers/gemma_decoder_block_(\d+)/ffw_linear/vars/0",
r"layers.\1.mlp.proj.weight"),
(r"layers/gemma_decoder_block_(\d+)/pre_ffw_norm/vars/0",
r"layers.\1.post_layernorm.weight"),
(r"layers/rms_normalization/vars/0", r"ln_f.weight"),
(r"optimizer/vars/(\d+)", None), # Not used
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params):
f_params = {}
def walk(name, obj):
if isinstance(obj, h5py.Dataset):
f_params[name] = np.array(obj)
params.visititems(walk)
return f_params
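# Parser for PyTorch Gemma checkpoints: loads the first *.ckpt file in the model
# directory, takes its 'model_state_dict', and drops the precomputed rotary table
# 'freqs_cis', which is not an actual model weight.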
class TorchParser:
def load_parameters(self, checkpoint_path: pathlib.Path):
ckpt_path = list(checkpoint_path.glob('*.ckpt'))[0]
model_params = torch.load(ckpt_path)['model_state_dict']
model_params.pop('freqs_cis')
return model_params
def embedding_weights(self, ckpt_params):
return ckpt_params["embedder.weight"]
def get_config(self, checkpoint_path, ckpt_params, num_embed):
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
with open(checkpoint_path / config_file, 'r') as f:
json_str = f.read()
json_str = json_str.replace("'", "\"")
json_str = json_str.replace(",\n}", "\n}")
config_old = json.loads(json_str)
config_new = {}
config_new["num_layers"] = config_old["num_hidden_layers"]
config_new["num_embed"] = config_old["vocab_size"]
config_new["embed_dim"] = config_old["hidden_size"]
config_new["hidden_dim"] = config_old["intermediate_size"]
config_new["num_heads"] = config_old["num_attention_heads"]
config_new["head_dim"] = config_old["head_dim"]
config_new["num_kv_heads"] = config_old["num_key_value_heads"]
return EasyDict(config_new)
def rename_to_trt_llm(self, name: str):
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
sub_patterns = (
(r"embedder.weight", r"vocab_embedding.weight"),
(r"model.layers.(\d+).input_layernorm.weight",
r"layers.\1.input_layernorm.weight"),
(r"model.layers.(\d+).self_attn.qkv_proj.weight",
r"layers.\1.attention.qkv.weight"),
(r"model.layers.(\d+).self_attn.o_proj.weight",
r"layers.\1.attention.dense.weight"),
(r"model.layers.(\d+).mlp.gate_proj.weight",
r"layers.\1.mlp.fc.weight"),
(r"model.layers.(\d+).mlp.up_proj.weight",
None), # merged with above
(r"model.layers.(\d+).mlp.down_proj.weight",
r"layers.\1.mlp.proj.weight"),
(r"model.layers.(\d+).post_attention_layernorm.weight",
r"layers.\1.post_layernorm.weight"),
(r"model.norm.weight", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {name}")
def flatten_params(self, params):
f_params = {}
for k, v in params.items():
if v.dtype == torch.bfloat16:
v = v.float()
f_params[k] = torch_to_numpy(v)
return f_params
CKPT_PARSER = {'jax': JAXParser, 'keras': KerasParser, 'torch': TorchParser}
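# Tensor-parallel helpers: split() slices a weight evenly along one axis and keeps
# the shard that belongs to the given rank. For example, a (4096, 16384) matrix with
# tp_size=2 and dim=1 yields a (4096, 8192) shard per rank.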
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
return np.split(v, tp_size, axis=dim)[idx]
def split_matrix_tp(v, tensor_parallel, rank, dim):
return split(v, tensor_parallel, rank, dim=dim)
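# add_trt_llm_weight() registers a converted tensor under its TRT-LLM name,
# optionally casting it to the target dtype and forcing a contiguous copy; the
# assert guards against writing the same parameter twice.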
def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
name: str,
param: np.ndarray,
dtype: typing.Optional[np.dtype] = None):
assert name not in weights, f"{name} is already added."
if dtype is not None:
param = param.astype(dtype)
param = np.ascontiguousarray(param)
weights[name] = param
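# quantize() performs symmetric per-channel weight-only quantization (INT8 or packed
# INT4) through the torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix
# custom op and returns the quantized weights plus their per-channel scales.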
def quantize(param: np.ndarray,
quant_mode: tensorrt_llm.quantization.QuantMode):
if quant_mode.is_int8_weight_only():
quant_dtype = torch.int8
elif quant_mode.is_int4_weight_only():
quant_dtype = torch.quint4x2
else:
raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")
if param.dtype == np.dtype("bfloat16"):
param = torch.from_numpy(param.astype(np.float32)).to(torch.bfloat16)
else:
param = torch.from_numpy(param)
param = param.t().contiguous()
# previously this fn was available in torch.ops.fastertransformer namespace
(
quantized_weights,
scales,
) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
param, quant_dtype)
if scales.dtype == torch.bfloat16:
scales = scales.to(torch.float32).numpy().astype("bfloat16")
else:
scales = scales.numpy()
return quantized_weights.numpy(), scales
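# convert_from_checkpoint() walks every source tensor, renames it to the TRT-LLM
# layout, and converts it for the given tensor-parallel rank:
#   * Q/K/V projections are fused into a single qkv.weight (only Q is split across
#     ranks in GQA mode; all heads are split otherwise).
#   * The gated MLP is emitted as mlp.fc.weight plus mlp.gate.weight, and
#     mlp.proj.weight is split along its input (intermediate) dimension.
#   * The embedding table is scaled by sqrt(hidden_size), matching Gemma's input
#     scaling, and optionally duplicated as lm_head.weight.
#   * RMSNorm scales get +1 folded in, matching Gemma's (1 + weight) formulation.
# Weight-only quantization, when enabled, is applied per tensor as it is converted.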
def convert_from_checkpoint(
trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
model_dir: typing.Union[str, pathlib.Path],
ckpt_parser,
rank=0,
):
print("Loading weights...")
tik = time.time()
tp_rank = rank
tp_size = trt_llm_config.mapping.tp_size
hidden_size = trt_llm_config.hidden_size
head_dim = trt_llm_config.head_size
weights = {}
for model_file in [model_dir]:
LOGGER.debug(f"Loading directory {str(model_file)}...")
model_params = ckpt_parser.load_parameters(model_file)
model_params = ckpt_parser.flatten_params(model_params)
for name, param in model_params.items():
LOGGER.debug(f"Converting weight {name}...")
trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
if trt_llm_name is None: # omit as used with other params
continue
if "attn.q_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert gqa_mode
# initial shape: (num_q_heads, hidden_size, head_dim)
q_param = param.transpose(1, 0, 2)
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
kv_name = name.replace("q_einsum", "kv_einsum")
kv_param = model_params[kv_name]
kv_param = kv_param.reshape(
trt_llm_config.num_key_value_heads * 2,
hidden_size,
head_dim,
).transpose(1, 0, 2)
                # -> (hidden_size, num_q_heads / tp_size + 2 * num_kv_heads, head_dim)
qkv_param = np.concatenate([q_param, kv_param], axis=1)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
qkv_param = qkv_param.transpose(1, 0)
# If int8 kv enabled, weight-only quantization will be done later.
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "self_attn.qkv_proj" in name:
q_param, k_param, v_param = np.split(param, [
trt_llm_config.num_attention_heads *
trt_llm_config.head_size,
trt_llm_config.num_attention_heads *
trt_llm_config.head_size +
trt_llm_config.num_key_value_heads *
trt_llm_config.head_size
],
axis=0)
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
if not gqa_mode:
k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)
v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)
qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_per_group_scaling():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attn.qkv_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert not gqa_mode
# initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
qkv_param = param.transpose(0, 1, 3, 2)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
qkv_param.shape[3])
qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() \
and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attention/query_dense" in name:
                # Keras-specific QKV conversion
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
if gqa_mode:
# initial shape: (num_q_heads, hidden_size, head_dim)
q_param = param.transpose(1, 0, 2)
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
k_name = name.replace("query", "key")
k_param = model_params[k_name]
v_name = name.replace("query", "value")
v_param = model_params[v_name]
kv_param = np.stack((k_param, v_param), axis=0)
kv_param = kv_param.reshape(
trt_llm_config.num_key_value_heads * 2,
hidden_size,
head_dim,
).transpose(1, 0, 2)
                    # -> (hidden_size, num_q_heads / tp_size + 2 * num_kv_heads, head_dim)
qkv_param = np.concatenate([q_param, kv_param], axis=1)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
qkv_param = qkv_param.transpose(1, 0)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight",
".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
else:
q_param = param
k_name = name.replace("query", "key")
k_param = model_params[k_name]
v_name = name.replace("query", "value")
v_param = model_params[v_name]
# initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
qkv_param = np.stack((q_param, k_param, v_param), axis=0)
qkv_param = qkv_param.transpose(0, 1, 3, 2)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
qkv_param.shape[3])
qkv_param = split_matrix_tp(qkv_param,
tp_size,
tp_rank,
dim=1)
qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight",
".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attention.dense.weight" in trt_llm_name:
# initial shape: (num_heads, head_dim, hidden_size)
if len(param.shape) == 3:
param = param.reshape(-1, param.shape[2])
param = param.transpose(
                        1, 0)  # (hidden_size, num_heads * head_dim)
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "mlp.fc.weight" in trt_llm_name:
if isinstance(ckpt_parser, KerasParser):
# initial shape: (hidden_size, intermediate_size)
fc_param, gate_param = param, model_params[name.replace(
"gating_ffw", "gating_ffw_2")]
elif isinstance(ckpt_parser, TorchParser):
# initial shape: (intermediate_size, hidden_size)
fc_param, gate_param = param, model_params[name.replace(
"mlp.gate_proj", "mlp.up_proj")]
fc_param = fc_param.transpose(1, 0)
gate_param = gate_param.transpose(1, 0)
else:
# initial shape: (2, hidden_size, intermediate_size)
fc_param, gate_param = param[0], param[1]
fc_param = fc_param.transpose(1, 0)
fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
fc_param_quantized, fc_param_scales = quantize(
fc_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
fc_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
fc_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, fc_param,
trt_llm_config.dtype)
gate_param = gate_param.transpose(1, 0)
gate_param = split_matrix_tp(gate_param,
tp_size,
tp_rank,
dim=0)
trt_llm_name = trt_llm_name.replace("mlp.fc.weight",
"mlp.gate.weight")
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
gate_param_quantized, gate_param_scales = quantize(
gate_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
gate_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
gate_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, gate_param,
trt_llm_config.dtype)
elif "mlp.proj.weight" in trt_llm_name:
if not isinstance(ckpt_parser, TorchParser):
# initial shape: (intermediate_size, hidden_size)
param = param.transpose(1, 0)
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name:
if not trt_llm_config.share_embedding_table:
# TODO: safetensor doesn't allow to save a shared tensor.
# Currently, we clone the weight but to save the disk, it
# would be better to skip saving lm_head weights and
# handle it at the loading phase.
lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
add_trt_llm_weight(weights, "lm_head.weight",
np.copy(lm_head), trt_llm_config.dtype)
param = np.multiply(
param.astype(np.float32),
math.sqrt(trt_llm_config.hidden_size),
)
if trt_llm_config.use_parallel_embedding:
assert trt_llm_config.vocab_size % tp_size == 0
param = split_matrix_tp(
param,
tp_size,
tp_rank,
dim=trt_llm_config.embedding_sharding_dim,
)
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif any(keyword in name for keyword in (
"pre_attention_norm.scale",
"pre_ffw_norm.scale",
"final_norm.scale",
"pre_attention_norm/vars/0",
"pre_ffw_norm/vars/0",
"rms_normalization/vars/0",
"input_layernorm",
"post_attention_layernorm",
"model.norm.weight",
)):
                param = param + 1.0  # Gemma RMSNorm scales by (1 + weight); fold the +1 into the stored weight (upcast to float32 in case of bfloat16)
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
else:
raise RuntimeError(f"Unhandled {name} module weights")
del model_params
print(
f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
)
return weights
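# convert() produces one safetensors file per tensor-parallel rank. When SmoothQuant
# or INT8 KV-cache calibration is requested, it rebuilds a torch model from the
# converted weights (create_model_from_config), captures activation ranges on the
# ccdv/cnn_dailymail dataset, and re-exports the weights with the calibrated scales.
# Otherwise it optionally attaches AWQ dummy weights or FP8 scaling factors loaded
# from --ammo_quant_ckpt_path before saving.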
def convert(worker_rank, args, convert_kwargs):
for rank in range(worker_rank, args.world_size):
weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
trt_llm_config = convert_kwargs.get("trt_llm_config")
if args.use_smooth_quant_plugin is not None or args.calibrate_kv_cache:
qkv_para = {}
smoother = {}
dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0')
tokenizer = sp.SentencePieceProcessor(model_file=args.tokenizer_dir)
hf_model = create_model_from_config(trt_llm_config, weights)
act_range = capture_activation_range(hf_model, tokenizer, dataset)
if args.use_smooth_quant_plugin is not None:
smooth_model(hf_model, act_range, args.use_smooth_quant_plugin,
qkv_para, smoother)
weights = convert_hf_model(
hf_model, trt_llm_config.mapping, trt_llm_config.vocab_size,
args.dtype, False, 0,
                args.use_weight_only_with_precision is not None,
torch.int8 if args.use_weight_only_with_precision == 'int8' else
torch.quint4x2, args.use_smooth_quant_plugin is not None,
args.per_channel, args.per_token, args.calibrate_kv_cache,
act_range, qkv_para, smoother)
safetensors.torch.save_file(
weights, args.output_model_dir / f"rank{rank}.safetensors")
return
use_awq = False
if args.use_weight_only_with_precision:
if args.use_weight_only_with_precision.endswith("awq"):
use_awq = True
if use_awq:
weights = dummy_weights_awq(
weights=weights,
precision=args.use_weight_only_with_precision,
trt_llm_config=trt_llm_config,
group_size=128)
elif args.enable_fp8 or args.fp8_kv_cache:
weight_scales = quantize_fp8_weights(
weights, trt_llm_config.num_hidden_layers,
trt_llm_config.mapping)
scales = load_from_fp8_llama(args.ammo_quant_ckpt_path,
trt_llm_config.num_hidden_layers,
trt_llm_config.mapping,
args.fp8_kv_cache, weight_scales)
weights.update(scales)
safetensors.numpy.save_file(
weights, args.output_model_dir / f"rank{rank}.safetensors")
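# main() ties everything together: parse arguments, load the source checkpoint,
# derive the model configuration and dtype from it, translate the quantization
# flags into a TRT-LLM quantization config, write config.json, and run the
# per-rank conversion.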
def main():
args = parse_arguments()
tik = time.time()
print(f"Loading source parameters from {args.model_dir.absolute()}")
ckpt_parser = CKPT_PARSER[args.ckpt_type]()
ckpt_params = ckpt_parser.load_parameters(args.model_dir)
input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
num_embed, _ = input_embedding_weights.shape
ckpt_params_dtype = str(
input_embedding_weights.dtype).split(".")[-1] # np.bfloat16 -> bfloat16
ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params, num_embed)
# 2B TransformerConfig(num_layers=18, num_embed=256128, embed_dim=2048, hidden_dim=16384, num_heads=8, head_dim=256, num_kv_heads=1)
# 7B TransformerConfig(...)
print(f"Source configuration determined from parameters: {ckpt_config}")
quant_kwargs = {}
quant_algo = None
kv_cache_quant_algo = None
if args.use_weight_only_with_precision:
quant_algo = {
"int8": "W8A16",
"int4": "W4A16",
"w4a8_awq": "W4A8_AWQ",
"w4a16_awq": "W4A16_AWQ",
}[args.use_weight_only_with_precision]
elif args.enable_fp8:
quant_algo = "FP8"
elif args.use_smooth_quant:
quant_algo = "W8A8_SQ_PER_CHANNEL"
if args.fp8_kv_cache:
kv_cache_quant_algo = "FP8"
if args.calibrate_kv_cache:
kv_cache_quant_algo = "INT8"
if args.use_smooth_quant:
quant_algo = "W8A8_SQ_PER_CHANNEL"
elif args.use_smooth_quant_plugin is not None:
if args.per_token and args.per_channel:
quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN'
elif not args.per_token and not args.per_channel:
quant_algo = 'W8A8_SQ_PER_TENSOR_PLUGIN'
elif not args.per_token and args.per_channel:
quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN'
elif args.per_token and not args.per_channel:
quant_algo = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN'
quant_kwargs.update(sq_use_plugin=True)
quant_kwargs.update(quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo)
if args.use_weight_only_with_precision:
if args.use_weight_only_with_precision.endswith("awq"):
quant_kwargs.update(has_zero_point=False,
pre_quant_scale=True,
exclude_modules=["lm_head"])
trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
architecture="GemmaForCausalLM",
dtype=args.dtype or ckpt_params_dtype,
logits_dtype="float32",
vocab_size=ckpt_config.num_embed,
max_position_embeddings=8192,
hidden_size=ckpt_config.embed_dim,
num_hidden_layers=ckpt_config.num_layers,
num_attention_heads=ckpt_config.num_heads,
num_key_value_heads=ckpt_config.num_kv_heads,
head_size=ckpt_config.head_dim,
hidden_act="gelu",
intermediate_size=ckpt_config.hidden_dim,
norm_epsilon=1e-6, # hard-coded in RMSNorm from gemma/layers.py
position_embedding_type="rope_gpt_neox",
world_size=args.world_size,
tp_size=args.world_size,
pp_size=1,
quantization=quant_kwargs,
)
trt_llm_config_dict = trt_llm_config.to_dict()
print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")
config_path = args.output_model_dir / "config.json"
config_path.parent.mkdir(exist_ok=True, parents=True)
LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
with config_path.open("w") as config_file:
json.dump(trt_llm_config_dict, config_file, indent=4)
convert_args = dict(trt_llm_config=trt_llm_config,
model_dir=args.model_dir,
ckpt_parser=ckpt_parser)
convert(0, args, convert_args)
elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
print(f"Total time of converting checkpoints: {elapsed}")
if __name__ == "__main__":
main()