#!/usr/bin/env python3
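"""Convert a Gemma checkpoint (JAX, Keras, or PyTorch) into a TensorRT-LLM checkpoint.

The script writes a `config.json` plus one `rank<N>.safetensors` file per
tensor-parallel rank into --output-model-dir. Example invocation (the paths
are illustrative):

    python3 convert_checkpoint.py \
        --ckpt-type jax \
        --model-dir /path/to/gemma_ckpt \
        --output-model-dir /path/to/trt_llm_ckpt \
        --dtype bfloat16 \
        --world-size 1
"""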
import argparse
import json
import logging
import math
import pathlib
import re
import time
import typing

import flax.traverse_util
import h5py
import numpy as np
import safetensors.numpy
import safetensors.torch
import sentencepiece as sp
import torch
import utils.params
import utils.transformer
from datasets import load_dataset
from easydict import EasyDict

import tensorrt_llm
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.models.gemma.smoothquant import *
from tensorrt_llm.models.gemma.weight import (dummy_weights_awq,
                                              load_from_fp8_llama,
                                              quantize_fp8_weights)

LOGGER = logging.getLogger("convert_checkpoint")


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-type",
                        type=str,
                        choices=["jax", "keras", "torch"])
    parser.add_argument("--model-dir", type=pathlib.Path, required=True)
    parser.add_argument("--output-model-dir", type=pathlib.Path, required=True)
    parser.add_argument("--world-size",
                        type=int,
                        default=1,
                        help="world size; only tensor parallelism is supported for now")
    parser.add_argument(
        "--use-weight-only-with-precision",
        choices=["int8", "int4", "w4a8_awq", "w4a16_awq"],
        help=
        "Quantize weights for the various GEMMs to INT4/INT8. Define the precision for the weights.",
    )
    parser.add_argument("--dtype",
                        type=str,
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--enable_fp8",
        action="store_true",
        help="Use FP8 Linear layer for Attention QKV/Dense and MLP.")
    parser.add_argument(
        "--fp8_kv_cache",
        action="store_true",
        help=
        "By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV",
    )
    parser.add_argument(
        "--ammo_quant_ckpt_path",
        default=None,
        help=
        "Path of a directory to quantized model checkpoints in .safetensors format or \
            path of a quantized model checkpoint in .npz format")
    parser.add_argument('--use_smooth_quant',
                        default=False,
                        action="store_true",
                        help="Use smooth quant.")
    parser.add_argument(
        "--calibrate_kv_cache",
        "-kv",
        action="store_true",
        help=
        "Generate scaling factors for KV cache. Used for storing KV cache in int8."
    )
    parser.add_argument(
        '--per_channel',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        "--use_smooth_quant_plugin",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]")
    parser.add_argument(
        '--tokenizer_dir',
        default=None,
        help='tokenizer path; defaults to jax_model_dir if left unspecified')

    args = parser.parse_args()

    return args


class JAXParser:

    def load_parameters(self, checkpoint_path: pathlib.Path):
        checkpoint_path = checkpoint_path.absolute()
        return utils.params.nest_params(
            utils.params.param_remapper(
                utils.params.load_params(checkpoint_path)))

    def embedding_weights(self, ckpt_params):
        return ckpt_params["transformer"]["embedder"]["input_embedding"]

    def get_config(self, checkpoint_path, ckpt_params, num_embed):
        return utils.transformer.TransformerConfig.from_params(
            ckpt_params, num_embed=num_embed)

    def rename_to_trt_llm(self, name: str):
        """Rename a gemma parameter name by the corresponding TRT-LLM style name."""
        prefix, name = name.split(".", maxsplit=1)
        assert prefix == "transformer"
        sub_patterns = (
            (r"embedder.input_embedding", r"vocab_embedding.weight"),
            (r"layer_(\d+).pre_attention_norm.scale",
             r"layers.\1.input_layernorm.weight"),
            (r"layer_(\d+).attn.q_einsum.w", r"layers.\1.attention.qkv.weight"),
            (r"layer_(\d+).attn.kv_einsum.w",
             None),  # drop as kv will be concatenated with q
            (r"layer_(\d+).attn.qkv_einsum.w",
             r"layers.\1.attention.qkv.weight"),
            (r"layer_(\d+).attn.attn_vec_einsum.w",
             r"layers.\1.attention.dense.weight"),
            (r"layer_(\d+).mlp.gating_einsum", r"layers.\1.mlp.fc.weight"),
            (r"layer_(\d+).mlp.linear", r"layers.\1.mlp.proj.weight"),
            (r"layer_(\d+).pre_ffw_norm.scale",
             r"layers.\1.post_layernorm.weight"),
            (r"final_norm.scale", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {prefix}.{name}")

    def flatten_params(self, params):
        return flax.traverse_util.flatten_dict(params, sep=".")


class KerasParser:

    def load_parameters(self, checkpoint_path: pathlib.Path):
        checkpoint_path = checkpoint_path.absolute()
        config_file = "config.json"
        weights_file = json.load(open(checkpoint_path / config_file))["weights"]
        h5_path = checkpoint_path / weights_file
        return h5py.File(h5_path, "r+")

    def embedding_weights(self, ckpt_params):
        return np.array(ckpt_params["layers/reversible_embedding/vars/0"])

    def get_config(self, checkpoint_path, ckpt_params, num_embed):
        checkpoint_path = checkpoint_path.absolute()
        config_file = "config.json"
        config_old = json.load(open(checkpoint_path / config_file))["config"]
        config_new = {}
        config_new["num_layers"] = config_old["num_layers"]
        config_new["num_embed"] = config_old["vocabulary_size"]
        config_new["embed_dim"] = config_old["hidden_dim"]
        config_new["hidden_dim"] = config_old["intermediate_dim"] // 2
        config_new["num_heads"] = config_old["num_query_heads"]
        config_new["head_dim"] = config_old["head_dim"]
        config_new["num_kv_heads"] = config_old["num_key_value_heads"]
        return EasyDict(config_new)

    def rename_to_trt_llm(self, name: str):
        """Rename a gemma parameter name by the corresponding TRT-LLM style name."""
        prefix = "transformer"
        name = name.replace("/gemma_decoder_block/", "/gemma_decoder_block_0/")
        sub_patterns = (
            (r"layers/reversible_embedding/vars/0", r"vocab_embedding.weight"),
            (r"layers/gemma_decoder_block_(\d+)/pre_attention_norm/vars/0",
             r"layers.\1.input_layernorm.weight"),
            (r"layers/gemma_decoder_block_(\d+)/attention/query_dense/vars/0",
             r"layers.\1.attention.qkv.weight"),
            (r"layers/gemma_decoder_block_(\d+)/attention/key_dense/vars/0",
             None),  # drop as k will be concatenated with q
            (r"layers/gemma_decoder_block_(\d+)/attention/value_dense/vars/0",
             None),  # drop as v will be concatenated with q
            (r"layers/gemma_decoder_block_(\d+)/attention/output_dense/vars/0",
             r"layers.\1.attention.dense.weight"),
            (r"layers/gemma_decoder_block_(\d+)/gating_ffw/vars/0",
             r"layers.\1.mlp.fc.weight"),
            (r"layers/gemma_decoder_block_(\d+)/gating_ffw_2/vars/0",
             None),  # merged with above
            (r"layers/gemma_decoder_block_(\d+)/ffw_linear/vars/0",
             r"layers.\1.mlp.proj.weight"),
            (r"layers/gemma_decoder_block_(\d+)/pre_ffw_norm/vars/0",
             r"layers.\1.post_layernorm.weight"),
            (r"layers/rms_normalization/vars/0", r"ln_f.weight"),
            (r"optimizer/vars/(\d+)", None),  # Not used
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {prefix}.{name}")

    def flatten_params(self, params):
        f_params = {}

        def walk(name, obj):
            if isinstance(obj, h5py.Dataset):
                f_params[name] = np.array(obj)

        params.visititems(walk)
        return f_params


class TorchParser:

    def load_parameters(self, checkpoint_path: pathlib.Path):
        ckpt_path = list(checkpoint_path.glob('*.ckpt'))[0]
        model_params = torch.load(ckpt_path)['model_state_dict']
        model_params.pop('freqs_cis')
        return model_params

    def embedding_weights(self, ckpt_params):
        return ckpt_params["embedder.weight"]

    def get_config(self, checkpoint_path, ckpt_params, num_embed):
        checkpoint_path = checkpoint_path.absolute()
        config_file = "config.json"
        with open(checkpoint_path / config_file, 'r') as f:
            json_str = f.read()
            json_str = json_str.replace("'", "\"")
            json_str = json_str.replace(",\n}", "\n}")
            config_old = json.loads(json_str)
        config_new = {}
        config_new["num_layers"] = config_old["num_hidden_layers"]
        config_new["num_embed"] = config_old["vocab_size"]
        config_new["embed_dim"] = config_old["hidden_size"]
        config_new["hidden_dim"] = config_old["intermediate_size"]
        config_new["num_heads"] = config_old["num_attention_heads"]
        config_new["head_dim"] = config_old["head_dim"]
        config_new["num_kv_heads"] = config_old["num_key_value_heads"]
        return EasyDict(config_new)

    def rename_to_trt_llm(self, name: str):
        """Rename a gemma parameter name by the corresponding TRT-LLM style name."""
        prefix = "transformer"
        sub_patterns = (
            (r"embedder.weight", r"vocab_embedding.weight"),
            (r"model.layers.(\d+).input_layernorm.weight",
             r"layers.\1.input_layernorm.weight"),
            (r"model.layers.(\d+).self_attn.qkv_proj.weight",
             r"layers.\1.attention.qkv.weight"),
            (r"model.layers.(\d+).self_attn.o_proj.weight",
             r"layers.\1.attention.dense.weight"),
            (r"model.layers.(\d+).mlp.gate_proj.weight",
             r"layers.\1.mlp.fc.weight"),
            (r"model.layers.(\d+).mlp.up_proj.weight",
             None),  # merged with above
            (r"model.layers.(\d+).mlp.down_proj.weight",
             r"layers.\1.mlp.proj.weight"),
            (r"model.layers.(\d+).post_attention_layernorm.weight",
             r"layers.\1.post_layernorm.weight"),
            (r"model.norm.weight", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params):
        f_params = {}
        for k, v in params.items():
            if v.dtype == torch.bfloat16:
                v = v.float()
            f_params[k] = torch_to_numpy(v)
        return f_params


CKPT_PARSER = {'jax': JAXParser, 'keras': KerasParser, 'torch': TorchParser}


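# Tensor-parallel sharding helpers used below: a weight is cut into tp_size
# equal chunks along `dim`, and rank `idx` keeps its own chunk. Illustrative
# example (numbers are arbitrary): with tp_size=2, a (2048, 4096) matrix split
# on dim=1 gives each rank a (2048, 2048) column shard.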
def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    return np.split(v, tp_size, axis=dim)[idx]


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
                       name: str,
                       param: np.ndarray,
                       dtype: typing.Optional[np.dtype] = None):
    assert name not in weights, f"{name} is already added."
    if dtype is not None:
        param = param.astype(dtype)
    param = np.ascontiguousarray(param)
    weights[name] = param


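# Weight-only quantization helper: the TRT-LLM op below quantizes symmetrically
# along the last axis, returning an integer weight plus one scale per output
# channel (stored later under "<name>.per_channel_scale"). Conceptually,
# dequantization is quantized_weights * scales broadcast over that axis.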
def quantize(param: np.ndarray,
             quant_mode: tensorrt_llm.quantization.QuantMode):
    if quant_mode.is_int8_weight_only():
        quant_dtype = torch.int8
    elif quant_mode.is_int4_weight_only():
        quant_dtype = torch.quint4x2
    else:
        raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")

    if param.dtype == np.dtype("bfloat16"):
        param = torch.from_numpy(param.astype(np.float32)).to(torch.bfloat16)
    else:
        param = torch.from_numpy(param)
    param = param.t().contiguous()

    # previously this fn was available in torch.ops.fastertransformer namespace
    (
        quantized_weights,
        scales,
    ) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
        param, quant_dtype)

    if scales.dtype == torch.bfloat16:
        scales = scales.to(torch.float32).numpy().astype("bfloat16")
    else:
        scales = scales.numpy()
    return quantized_weights.numpy(), scales


def convert_from_checkpoint(
        trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
        model_dir: typing.Union[str, pathlib.Path],
        ckpt_parser,
        rank=0,
):
    print("Loading weights...")
    tik = time.time()

    tp_rank = rank
    tp_size = trt_llm_config.mapping.tp_size
    hidden_size = trt_llm_config.hidden_size
    head_dim = trt_llm_config.head_size

    weights = {}
    for model_file in [model_dir]:
        LOGGER.debug(f"Loading directory {str(model_file)}...")
        model_params = ckpt_parser.load_parameters(model_file)
        model_params = ckpt_parser.flatten_params(model_params)

        for name, param in model_params.items():
            LOGGER.debug(f"Converting weight {name}...")
            trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
            if trt_llm_name is None:  # omit as used with other params
                continue

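            # Each branch below repacks one family of checkpoint tensors into
            # the fused TRT-LLM layout (Q/K/V fused into one matrix, MLP gate
            # and up projections fused, etc.) and shards it for tensor
            # parallelism before adding it to `weights`.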
if "attn.q_einsum" in name:
|
||
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
|
||
assert gqa_mode
|
||
|
||
# initial shape: (num_q_heads, hidden_size, head_dim)
|
||
q_param = param.transpose(1, 0, 2)
|
||
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
|
||
|
||
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
|
||
kv_name = name.replace("q_einsum", "kv_einsum")
|
||
kv_param = model_params[kv_name]
|
||
kv_param = kv_param.reshape(
|
||
trt_llm_config.num_key_value_heads * 2,
|
||
hidden_size,
|
||
head_dim,
|
||
).transpose(1, 0, 2)
|
||
|
||
# -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
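                # Worked example (2B checkpoint, tp_size=1, using the config
                # noted in main(): num_heads=8, num_kv_heads=1, head_dim=256,
                # embed_dim=2048): q_param is (2048, 8, 256), kv_param is
                # (2048, 2, 256); concatenating and flattening gives a fused
                # QKV weight of shape (2560, 2048).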
                qkv_param = np.concatenate([q_param, kv_param], axis=1)
                qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
                qkv_param = qkv_param.transpose(1, 0)

                # If int8 kv enabled, weight-only quantization will be done later.
                if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                        not trt_llm_config.quant_mode.has_int8_kv_cache():
                    qkv_param_quantized, qkv_param_scales = quantize(
                        qkv_param, trt_llm_config.quant_mode)
                    add_trt_llm_weight(weights, trt_llm_name,
                                       qkv_param_quantized)
                    add_trt_llm_weight(
                        weights,
                        trt_llm_name.replace(".weight", ".per_channel_scale"),
                        qkv_param_scales,
                        trt_llm_config.dtype,
                    )
                else:
                    add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                       trt_llm_config.dtype)
            elif "self_attn.qkv_proj" in name:
                q_param, k_param, v_param = np.split(param, [
                    trt_llm_config.num_attention_heads *
                    trt_llm_config.head_size,
                    trt_llm_config.num_attention_heads *
                    trt_llm_config.head_size +
                    trt_llm_config.num_key_value_heads *
                    trt_llm_config.head_size
                ],
                                                     axis=0)
                gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads

                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
                if not gqa_mode:
                    k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)
                    v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)

                qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
                if trt_llm_config.quant_mode.is_weight_only(
                ) and not trt_llm_config.quant_mode.has_per_group_scaling():
                    qkv_param_quantized, qkv_param_scales = quantize(
                        qkv_param, trt_llm_config.quant_mode)
                    add_trt_llm_weight(weights, trt_llm_name,
                                       qkv_param_quantized)
                    add_trt_llm_weight(
                        weights,
                        trt_llm_name.replace(".weight", ".per_channel_scale"),
                        qkv_param_scales,
                        trt_llm_config.dtype,
                    )
                else:
                    add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                       trt_llm_config.dtype)
            elif "attn.qkv_einsum" in name:
                gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
                assert not gqa_mode
                # initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
                qkv_param = param.transpose(0, 1, 3, 2)
                qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
                                              qkv_param.shape[3])
                qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
                qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
                if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() \
                        and not trt_llm_config.quant_mode.has_int8_kv_cache():
                    qkv_param_quantized, qkv_param_scales = quantize(
                        qkv_param, trt_llm_config.quant_mode)
                    add_trt_llm_weight(weights, trt_llm_name,
                                       qkv_param_quantized)
                    add_trt_llm_weight(
                        weights,
                        trt_llm_name.replace(".weight", ".per_channel_scale"),
                        qkv_param_scales,
                        trt_llm_config.dtype,
                    )
                else:
                    add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                       trt_llm_config.dtype)
            elif "attention/query_dense" in name:
                # Keras specific KQV convert
                gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
                if gqa_mode:

                    # initial shape: (num_q_heads, hidden_size, head_dim)
                    q_param = param.transpose(1, 0, 2)
                    q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)

                    # initial shape: (2, num_kv_heads, hidden_size, head_dim)
                    k_name = name.replace("query", "key")
                    k_param = model_params[k_name]
                    v_name = name.replace("query", "value")
                    v_param = model_params[v_name]
                    kv_param = np.stack((k_param, v_param), axis=0)

                    kv_param = kv_param.reshape(
                        trt_llm_config.num_key_value_heads * 2,
                        hidden_size,
                        head_dim,
                    ).transpose(1, 0, 2)

                    # -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
                    qkv_param = np.concatenate([q_param, kv_param], axis=1)
                    qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
                    qkv_param = qkv_param.transpose(1, 0)

                    if trt_llm_config.quant_mode.is_weight_only(
                    ) and not trt_llm_config.quant_mode.has_int8_kv_cache():
                        qkv_param_quantized, qkv_param_scales = quantize(
                            qkv_param, trt_llm_config.quant_mode)
                        add_trt_llm_weight(weights, trt_llm_name,
                                           qkv_param_quantized)
                        add_trt_llm_weight(
                            weights,
                            trt_llm_name.replace(".weight",
                                                 ".per_channel_scale"),
                            qkv_param_scales,
                            trt_llm_config.dtype,
                        )
                    else:
                        add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                           trt_llm_config.dtype)
                else:
                    q_param = param
                    k_name = name.replace("query", "key")
                    k_param = model_params[k_name]
                    v_name = name.replace("query", "value")
                    v_param = model_params[v_name]
                    # initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
                    qkv_param = np.stack((q_param, k_param, v_param), axis=0)
                    qkv_param = qkv_param.transpose(0, 1, 3, 2)
                    qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
                                                  qkv_param.shape[3])
                    qkv_param = split_matrix_tp(qkv_param,
                                                tp_size,
                                                tp_rank,
                                                dim=1)
                    qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
                    if trt_llm_config.quant_mode.is_weight_only(
                    ) and not trt_llm_config.quant_mode.has_int8_kv_cache():
                        qkv_param_quantized, qkv_param_scales = quantize(
                            qkv_param, trt_llm_config.quant_mode)
                        add_trt_llm_weight(weights, trt_llm_name,
                                           qkv_param_quantized)
                        add_trt_llm_weight(
                            weights,
                            trt_llm_name.replace(".weight",
                                                 ".per_channel_scale"),
                            qkv_param_scales,
                            trt_llm_config.dtype,
                        )
                    else:
                        add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                           trt_llm_config.dtype)
elif "attention.dense.weight" in trt_llm_name:
|
||
# initial shape: (num_heads, head_dim, hidden_size)
|
||
if len(param.shape) == 3:
|
||
param = param.reshape(-1, param.shape[2])
|
||
param = param.transpose(
|
||
1, 0) # (hidden_size, num_heads * head_dum)
|
||
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
|
||
if trt_llm_config.quant_mode.is_weight_only(
|
||
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
|
||
param_quantized, param_scales = quantize(
|
||
param, trt_llm_config.quant_mode)
|
||
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
|
||
add_trt_llm_weight(
|
||
weights,
|
||
trt_llm_name.replace(".weight", ".per_channel_scale"),
|
||
param_scales,
|
||
trt_llm_config.dtype,
|
||
)
|
||
else:
|
||
add_trt_llm_weight(weights, trt_llm_name, param,
|
||
trt_llm_config.dtype)
|
||
elif "mlp.fc.weight" in trt_llm_name:
|
||
if isinstance(ckpt_parser, KerasParser):
|
||
# initial shape: (hidden_size, intermediate_size)
|
||
fc_param, gate_param = param, model_params[name.replace(
|
||
"gating_ffw", "gating_ffw_2")]
|
||
elif isinstance(ckpt_parser, TorchParser):
|
||
# initial shape: (intermediate_size, hidden_size)
|
||
fc_param, gate_param = param, model_params[name.replace(
|
||
"mlp.gate_proj", "mlp.up_proj")]
|
||
fc_param = fc_param.transpose(1, 0)
|
||
gate_param = gate_param.transpose(1, 0)
|
||
else:
|
||
# initial shape: (2, hidden_size, intermediate_size)
|
||
fc_param, gate_param = param[0], param[1]
|
||
fc_param = fc_param.transpose(1, 0)
|
||
fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
|
||
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
|
||
not trt_llm_config.quant_mode.has_int8_kv_cache():
|
||
fc_param_quantized, fc_param_scales = quantize(
|
||
fc_param, trt_llm_config.quant_mode)
|
||
add_trt_llm_weight(weights, trt_llm_name,
|
||
fc_param_quantized)
|
||
add_trt_llm_weight(
|
||
weights,
|
||
trt_llm_name.replace(".weight", ".per_channel_scale"),
|
||
fc_param_scales,
|
||
trt_llm_config.dtype,
|
||
)
|
||
else:
|
||
add_trt_llm_weight(weights, trt_llm_name, fc_param,
|
||
trt_llm_config.dtype)
|
||
|
||
gate_param = gate_param.transpose(1, 0)
|
||
gate_param = split_matrix_tp(gate_param,
|
||
tp_size,
|
||
tp_rank,
|
||
dim=0)
|
||
trt_llm_name = trt_llm_name.replace("mlp.fc.weight",
|
||
"mlp.gate.weight")
|
||
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
|
||
not trt_llm_config.quant_mode.has_int8_kv_cache():
|
||
gate_param_quantized, gate_param_scales = quantize(
|
||
gate_param, trt_llm_config.quant_mode)
|
||
add_trt_llm_weight(weights, trt_llm_name,
|
||
gate_param_quantized)
|
||
add_trt_llm_weight(
|
||
weights,
|
||
trt_llm_name.replace(".weight", ".per_channel_scale"),
|
||
gate_param_scales,
|
||
trt_llm_config.dtype,
|
||
)
|
||
else:
|
||
add_trt_llm_weight(weights, trt_llm_name, gate_param,
|
||
trt_llm_config.dtype)
|
||
elif "mlp.proj.weight" in trt_llm_name:
|
||
if not isinstance(ckpt_parser, TorchParser):
|
||
# initial shape: (intermediate_size, hidden_size)
|
||
param = param.transpose(1, 0)
|
||
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
|
||
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
|
||
not trt_llm_config.quant_mode.has_int8_kv_cache():
|
||
param_quantized, param_scales = quantize(
|
||
param, trt_llm_config.quant_mode)
|
||
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
|
||
add_trt_llm_weight(
|
||
weights,
|
||
trt_llm_name.replace(".weight", ".per_channel_scale"),
|
||
param_scales,
|
||
trt_llm_config.dtype,
|
||
)
|
||
else:
|
||
add_trt_llm_weight(weights, trt_llm_name, param,
|
||
trt_llm_config.dtype)
|
||
elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name:
|
||
if not trt_llm_config.share_embedding_table:
|
||
# TODO: safetensor doesn't allow to save a shared tensor.
|
||
# Currently, we clone the weight but to save the disk, it
|
||
# would be better to skip saving lm_head weights and
|
||
# handle it at the loading phase.
|
||
lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
|
||
add_trt_llm_weight(weights, "lm_head.weight",
|
||
np.copy(lm_head), trt_llm_config.dtype)
|
||
|
||
param = np.multiply(
|
||
param.astype(np.float32),
|
||
math.sqrt(trt_llm_config.hidden_size),
|
||
)
|
||
if trt_llm_config.use_parallel_embedding:
|
||
assert trt_llm_config.vocab_size % tp_size == 0
|
||
param = split_matrix_tp(
|
||
param,
|
||
tp_size,
|
||
tp_rank,
|
||
dim=trt_llm_config.embedding_sharding_dim,
|
||
)
|
||
add_trt_llm_weight(weights, trt_llm_name, param,
|
||
trt_llm_config.dtype)
|
||
elif any(keyword in name for keyword in (
|
||
"pre_attention_norm.scale",
|
||
"pre_ffw_norm.scale",
|
||
"final_norm.scale",
|
||
"pre_attention_norm/vars/0",
|
||
"pre_ffw_norm/vars/0",
|
||
"rms_normalization/vars/0",
|
||
"input_layernorm",
|
||
"post_attention_layernorm",
|
||
"model.norm.weight",
|
||
)):
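                # Gemma's RMSNorm applies (1 + weight) as its scale, so adding
                # 1.0 here folds that offset into a standard RMSNorm weight.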
                param = param + 1.0  # upcasted to float32 in case of bfloat16
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
            else:
                raise RuntimeError(f"Unhandled {name} module weights")
        del model_params

    print(
        f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
    )
    return weights


def convert(worker_rank, args, convert_kwargs):
    for rank in range(worker_rank, args.world_size):
        weights = convert_from_checkpoint(rank=rank, **convert_kwargs)
        trt_llm_config = convert_kwargs.get("trt_llm_config")
        if args.use_smooth_quant_plugin is not None or args.calibrate_kv_cache:
            qkv_para = {}
            smoother = {}
            dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0')
            tokenizer = sp.SentencePieceProcessor(model_file=args.tokenizer_dir)
            hf_model = create_model_from_config(trt_llm_config, weights)
            act_range = capture_activation_range(hf_model, tokenizer, dataset)
            if args.use_smooth_quant_plugin is not None:
                smooth_model(hf_model, act_range, args.use_smooth_quant_plugin,
                             qkv_para, smoother)
            weights = convert_hf_model(
                hf_model, trt_llm_config.mapping, trt_llm_config.vocab_size,
                args.dtype, False, 0,
                args.use_weight_only_with_precision is not None,
                torch.int8 if args.use_weight_only_with_precision == 'int8' else
                torch.quint4x2, args.use_smooth_quant_plugin is not None,
                args.per_channel, args.per_token, args.calibrate_kv_cache,
                act_range, qkv_para, smoother)
            safetensors.torch.save_file(
                weights, args.output_model_dir / f"rank{rank}.safetensors")
            return

        use_awq = False
        if args.use_weight_only_with_precision:
            if args.use_weight_only_with_precision.endswith("awq"):
                use_awq = True
        if use_awq:
            weights = dummy_weights_awq(
                weights=weights,
                precision=args.use_weight_only_with_precision,
                trt_llm_config=trt_llm_config,
                group_size=128)
        elif args.enable_fp8 or args.fp8_kv_cache:
            weight_scales = quantize_fp8_weights(
                weights, trt_llm_config.num_hidden_layers,
                trt_llm_config.mapping)
            scales = load_from_fp8_llama(args.ammo_quant_ckpt_path,
                                         trt_llm_config.num_hidden_layers,
                                         trt_llm_config.mapping,
                                         args.fp8_kv_cache, weight_scales)
            weights.update(scales)

        safetensors.numpy.save_file(
            weights, args.output_model_dir / f"rank{rank}.safetensors")


def main():
    args = parse_arguments()

    tik = time.time()

    print(f"Loading source parameters from {args.model_dir.absolute()}")
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()
    ckpt_params = ckpt_parser.load_parameters(args.model_dir)
    input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
    num_embed, _ = input_embedding_weights.shape
    ckpt_params_dtype = str(
        input_embedding_weights.dtype).split(".")[-1]  # np.bfloat16 -> bfloat16
    ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params, num_embed)
    # 2B TransformerConfig(num_layers=18, num_embed=256128, embed_dim=2048, hidden_dim=16384, num_heads=8, head_dim=256, num_kv_heads=1)
    # 7B TransformerConfig(...)

    print(f"Source configuration determined from parameters: {ckpt_config}")

    quant_kwargs = {}
    quant_algo = None
    kv_cache_quant_algo = None
    if args.use_weight_only_with_precision:
        quant_algo = {
            "int8": "W8A16",
            "int4": "W4A16",
            "w4a8_awq": "W4A8_AWQ",
            "w4a16_awq": "W4A16_AWQ",
        }[args.use_weight_only_with_precision]
    elif args.enable_fp8:
        quant_algo = "FP8"
    elif args.use_smooth_quant:
        quant_algo = "W8A8_SQ_PER_CHANNEL"

    if args.fp8_kv_cache:
        kv_cache_quant_algo = "FP8"
    if args.calibrate_kv_cache:
        kv_cache_quant_algo = "INT8"
    if args.use_smooth_quant:
        quant_algo = "W8A8_SQ_PER_CHANNEL"
    elif args.use_smooth_quant_plugin is not None:
        if args.per_token and args.per_channel:
            quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN'
        elif not args.per_token and not args.per_channel:
            quant_algo = 'W8A8_SQ_PER_TENSOR_PLUGIN'
        elif not args.per_token and args.per_channel:
            quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN'
        elif args.per_token and not args.per_channel:
            quant_algo = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN'
        quant_kwargs.update(sq_use_plugin=True)

    quant_kwargs.update(quant_algo=quant_algo,
                        kv_cache_quant_algo=kv_cache_quant_algo)
    if args.use_weight_only_with_precision:
        if args.use_weight_only_with_precision.endswith("awq"):
            quant_kwargs.update(has_zero_point=False,
                                pre_quant_scale=True,
                                exclude_modules=["lm_head"])

    trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig(
        architecture="GemmaForCausalLM",
        dtype=args.dtype or ckpt_params_dtype,
        logits_dtype="float32",
        vocab_size=ckpt_config.num_embed,
        max_position_embeddings=8192,
        hidden_size=ckpt_config.embed_dim,
        num_hidden_layers=ckpt_config.num_layers,
        num_attention_heads=ckpt_config.num_heads,
        num_key_value_heads=ckpt_config.num_kv_heads,
        head_size=ckpt_config.head_dim,
        hidden_act="gelu",
        intermediate_size=ckpt_config.hidden_dim,
        norm_epsilon=1e-6,  # hard-coded in RMSNorm from gemma/layers.py
        position_embedding_type="rope_gpt_neox",
        world_size=args.world_size,
        tp_size=args.world_size,
        pp_size=1,
        quantization=quant_kwargs,
    )

    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    config_path = args.output_model_dir / "config.json"
    config_path.parent.mkdir(exist_ok=True, parents=True)
    LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}")
    with config_path.open("w") as config_file:
        json.dump(trt_llm_config_dict, config_file, indent=4)

    convert_args = dict(trt_llm_config=trt_llm_config,
                        model_dir=args.model_dir,
                        ckpt_parser=ckpt_parser)
    convert(0, args, convert_args)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()