# TensorRT-LLMs/tensorrt_llm/models/gemma/convert.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
import re
import time
from argparse import Namespace
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Optional, Union
import h5py
import numpy as np
import sentencepiece as sp
import torch
import tensorrt_llm
import tensorrt_llm.models.modeling_utils as tllm_utils
from tensorrt_llm._utils import (np_bfloat16, numpy_to_torch,
str_dtype_to_torch, torch_to_numpy)
from tensorrt_llm.logger import logger
from tensorrt_llm.models.convert_utils import load_calib_dataset
from tensorrt_llm.models.gemma.config import GemmaConfig
from tensorrt_llm.models.gemma.smoothquant import (capture_activation_range,
convert_hf_model,
create_model_from_config,
smooth_model)
from tensorrt_llm.models.gemma.utils import params as gemma_params
from tensorrt_llm.models.gemma.weight import (dummy_weights_awq,
load_from_fp8_gemma,
quantize_fp8_weights)
from tensorrt_llm.quantization.mode import QuantAlgo
Weights = Dict[str, torch.Tensor]
Flattened = Dict[str, np.ndarray]
if TYPE_CHECKING:
import transformers
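# Checkpoint parsers. Each parser exposes the same informal interface
# (load_parameters / embedding_weights / get_config / rename_to_trt_llm /
# flatten_params) for one source format: JAX/Flax checkpoints (JAXParser),
# Keras checkpoints (KerasParser), raw PyTorch checkpoints (TorchParser),
# and Hugging Face models (HfParser).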
class JAXParser:
def load_parameters(self,
checkpoint_path: Path,
load_model_on_cpu: bool = False) -> Weights:
checkpoint_path = checkpoint_path.absolute()
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
return gemma_params.nest_params(
gemma_params.param_remapper(
gemma_params.load_params(checkpoint_path)))
def embedding_weights(self, ckpt_params) -> np.ndarray:
return ckpt_params["transformer"]["embedder"]["input_embedding"]
def get_config(self, checkpoint_path, ckpt_params,
num_embed) -> SimpleNamespace:
from tensorrt_llm.models.gemma.utils.transformer import \
TransformerConfig
return TransformerConfig.from_params(ckpt_params, num_embed=num_embed)
def rename_to_trt_llm(self, name: str) -> Optional[str]:
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix, name = name.split(".", maxsplit=1)
assert prefix == "transformer"
sub_patterns = (
(r"embedder.input_embedding", r"vocab_embedding.weight"),
(r"layer_(\d+).pre_attention_norm.scale",
r"layers.\1.input_layernorm.weight"),
(r"layer_(\d+).attn.q_einsum.w", r"layers.\1.attention.qkv.weight"),
(r"layer_(\d+).attn.kv_einsum.w",
None), # drop as kv will be concatenated with q
(r"layer_(\d+).attn.qkv_einsum.w",
r"layers.\1.attention.qkv.weight"),
(r"layer_(\d+).attn.attn_vec_einsum.w",
r"layers.\1.attention.dense.weight"),
(r"layer_(\d+).mlp.gating_einsum", r"layers.\1.mlp.fc.weight"),
(r"layer_(\d+).mlp.linear", r"layers.\1.mlp.proj.weight"),
(r"layer_(\d+).pre_ffw_norm.scale",
r"layers.\1.post_layernorm.weight"),
(r"final_norm.scale", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params) -> Flattened:
import flax.traverse_util
new_params = flax.traverse_util.flatten_dict(params, sep=".")
# cast anything that is not float32/float16 (e.g. bfloat16) to float32
for k in new_params:
if new_params[k].dtype != np.float32 and new_params[
k].dtype != np.float16:
new_params[k] = new_params[k].astype(np.float32)
return new_params
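# KerasParser reads a Keras Gemma checkpoint: `config.json` names the HDF5
# weights file, which is then walked with h5py; bfloat16 ("|V2") datasets are
# upcast to float32 during flattening.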
class KerasParser:
def load_parameters(self,
checkpoint_path: Path,
load_model_on_cpu: bool = False) -> h5py.File:
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
weights_file = json.load(open(checkpoint_path / config_file))["weights"]
h5_path = checkpoint_path / weights_file
return h5py.File(h5_path, "r")
def embedding_weights(self, ckpt_params) -> np.ndarray:
return np.array(ckpt_params["layers/reversible_embedding/vars/0"])
def get_config(self, checkpoint_path, ckpt_params,
num_embed) -> SimpleNamespace:
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
config_old = json.load(open(checkpoint_path / config_file))["config"]
config_new = SimpleNamespace()
config_new.num_layers = config_old["num_layers"]
config_new.num_embed = config_old["vocabulary_size"]
config_new.embed_dim = config_old["hidden_dim"]
config_new.hidden_dim = config_old["intermediate_dim"] // 2
config_new.num_heads = config_old["num_query_heads"]
config_new.head_dim = config_old["head_dim"]
config_new.num_kv_heads = config_old["num_key_value_heads"]
return config_new
def rename_to_trt_llm(self, name: str) -> Optional[str]:
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
name = name.replace("/gemma_decoder_block/", "/gemma_decoder_block_0/")
sub_patterns = (
(r"layers/reversible_embedding/vars/0", r"vocab_embedding.weight"),
(r"layers/gemma_decoder_block_(\d+)/pre_attention_norm/vars/0",
r"layers.\1.input_layernorm.weight"),
(r"layers/gemma_decoder_block_(\d+)/attention/query_dense/vars/0",
r"layers.\1.attention.qkv.weight"),
(r"layers/gemma_decoder_block_(\d+)/attention/key_dense/vars/0",
None), # drop as k will be concatenated with q
(r"layers/gemma_decoder_block_(\d+)/attention/value_dense/vars/0",
None), # drop as v will be concatenated with q
(r"layers/gemma_decoder_block_(\d+)/attention/output_dense/vars/0",
r"layers.\1.attention.dense.weight"),
(r"layers/gemma_decoder_block_(\d+)/gating_ffw/vars/0",
r"layers.\1.mlp.fc.weight"),
(r"layers/gemma_decoder_block_(\d+)/gating_ffw_2/vars/0",
None), # merged with above
(r"layers/gemma_decoder_block_(\d+)/ffw_linear/vars/0",
r"layers.\1.mlp.proj.weight"),
(r"layers/gemma_decoder_block_(\d+)/pre_ffw_norm/vars/0",
r"layers.\1.post_layernorm.weight"),
(r"layers/rms_normalization/vars/0", r"ln_f.weight"),
(r"optimizer/vars/(\d+)", None), # Not used
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params) -> Flattened:
f_params = {}
def walk(name, obj):
if isinstance(obj, h5py.Dataset):
if obj.dtype == "|V2":
# bfloat16 case
f_params[name] = torch_to_numpy(
numpy_to_torch(np.array(obj).astype(np_bfloat16)).to(
torch.float32))
else:
f_params[name] = np.array(obj)
params.visititems(walk)
return f_params
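# TorchParser reads a raw PyTorch Gemma checkpoint: a single `*.ckpt` file
# holding `model_state_dict` (the precomputed `freqs_cis` entry is dropped),
# plus a loosely formatted `config.json` in which single quotes and trailing
# commas are tolerated.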
class TorchParser:
def load_parameters(self,
checkpoint_path: Path,
load_model_on_cpu: bool = False) -> Weights:
ckpt_path = list(checkpoint_path.glob('*.ckpt'))[0]
model_params = torch.load(ckpt_path)['model_state_dict']
model_params.pop('freqs_cis')
return model_params
def embedding_weights(self, ckpt_params) -> np.ndarray:
return ckpt_params["embedder.weight"]
def get_config(self, checkpoint_path, ckpt_params,
num_embed) -> SimpleNamespace:
checkpoint_path = checkpoint_path.absolute()
config_file = "config.json"
with open(checkpoint_path / config_file, 'r') as f:
json_str = f.read()
json_str = json_str.replace("'", "\"")
json_str = json_str.replace(",\n}", "\n}")
config_old = json.loads(json_str)
config_new = SimpleNamespace()
config_new.num_layers = config_old["num_hidden_layers"]
config_new.num_embed = config_old["vocab_size"]
config_new.embed_dim = config_old["hidden_size"]
config_new.hidden_dim = config_old["intermediate_size"]
config_new.num_heads = config_old["num_attention_heads"]
config_new.head_dim = config_old["head_dim"]
config_new.num_kv_heads = config_old["num_key_value_heads"]
return config_new
def rename_to_trt_llm(self, name: str) -> Optional[str]:
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
sub_patterns = (
(r"embedder.weight", r"vocab_embedding.weight"),
(r"model.layers.(\d+).input_layernorm.weight",
r"layers.\1.input_layernorm.weight"),
(r"model.layers.(\d+).self_attn.qkv_proj.weight",
r"layers.\1.attention.qkv.weight"),
(r"model.layers.(\d+).self_attn.o_proj.weight",
r"layers.\1.attention.dense.weight"),
(r"model.layers.(\d+).mlp.gate_proj.weight",
r"layers.\1.mlp.fc.weight"),
(r"model.layers.(\d+).mlp.up_proj.weight",
None), # merged with above
(r"model.layers.(\d+).mlp.down_proj.weight",
r"layers.\1.mlp.proj.weight"),
(r"model.layers.(\d+).post_attention_layernorm.weight",
r"layers.\1.post_layernorm.weight"),
(r"model.norm.weight", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {name}")
def flatten_params(self, params) -> Flattened:
f_params = {}
for k, v in params.items():
if v.dtype == torch.bfloat16:
v = v.float()
f_params[k] = torch_to_numpy(v)
return f_params
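# HfParser delegates loading to `transformers.AutoModelForCausalLM`. The model
# config is handled separately by `GemmaConfig.from_hugging_face`, so only the
# renaming and flattening helpers of this class are used.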
class HfParser:
def load_parameters(self,
checkpoint_path: Path,
load_model_on_cpu: bool = False) -> Weights:
"""`AutoModelForCausalLM.from_pretrained` will parse the correct gemma, whether Gemma or Gemma2 or future versions."""
from transformers import AutoModelForCausalLM
hf_model = AutoModelForCausalLM.from_pretrained(
checkpoint_path,
device_map="cpu" if load_model_on_cpu else "auto",
torch_dtype='auto',
trust_remote_code=True,
)
model_params = dict(hf_model.named_parameters())
return model_params
def embedding_weights(self, ckpt_params) -> np.ndarray:
raise RuntimeError(
"This method shouldn't be called - `GemmaConfig.from_hugging_face` takes care of this."
)
def get_config(self, checkpoint_path, ckpt_params,
num_embed) -> SimpleNamespace:
raise RuntimeError(
"This method shouldn't be called - `GemmaConfig.from_hugging_face` takes care of this."
)
def rename_to_trt_llm(self, name: str) -> Optional[str]:
"""Rename a gemma parameter name by the corresponding TRT-LLM style name."""
prefix = "transformer"
sub_patterns = (
(r"model.embed_tokens.weight", r"vocab_embedding.weight"),
(r"model.layers.(\d+).input_layernorm.weight",
r"layers.\1.input_layernorm.weight"),
(r"model.layers.(\d+).self_attn.q_proj.weight",
r"layers.\1.attention.qkv.weight"),
(r"model.layers.(\d+).self_attn.k_proj.weight",
None), # merged with above
(r"model.layers.(\d+).self_attn.v_proj.weight",
None), # merged with above
(r"model.layers.(\d+).self_attn.o_proj.weight",
r"layers.\1.attention.dense.weight"),
(r"model.layers.(\d+).mlp.gate_proj.weight",
r"layers.\1.mlp.fc.weight"),
(r"model.layers.(\d+).mlp.up_proj.weight",
None), # merged with above
(r"model.layers.(\d+).mlp.down_proj.weight",
r"layers.\1.mlp.proj.weight"),
(r"model.layers.(\d+).post_attention_layernorm.weight",
r"layers.\1.post_layernorm.weight"),
(r"model.layers.(\d+).pre_feedforward_layernorm.weight",
r"layers.\1.pre_feedforward_layernorm.weight"),
(r"model.layers.(\d+).post_feedforward_layernorm.weight",
r"layers.\1.post_feedforward_layernorm.weight"),
(r"model.norm.weight", r"ln_f.weight"),
)
for source, target in sub_patterns:
if re.match(source, name):
if target is None:
return target
else:
name = re.sub(source, target, name)
return ".".join((prefix, name))
else:
raise ValueError(f"Don't know how to rename {prefix}.{name}")
def flatten_params(self, params: Weights) -> Flattened:
f_params = {}
for k, v in params.items():
if v.dtype == torch.bfloat16:
v = v.float()
f_params[k] = torch_to_numpy(v)
return f_params
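# Tensor-parallel helpers: `split` returns the `idx`-th of `tp_size` equal
# chunks of `v` along axis `dim`, or the tensor unchanged when tp_size == 1.
# A minimal sketch (illustrative shapes only):
#
#   w = np.zeros((8, 16))
#   split(w, tp_size=2, idx=1, dim=0).shape   # (4, 16)
#   split_matrix_tp(w, 2, 0, dim=1).shape     # (8, 8)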
def split(v: np.ndarray, tp_size: int, idx: int, dim: int = 0) -> np.ndarray:
if tp_size == 1:
return v
return np.split(v, tp_size, axis=dim)[idx]
def split_matrix_tp(v: np.ndarray, tensor_parallel: int, rank: int,
dim: int) -> np.ndarray:
return split(v, tensor_parallel, rank, dim=dim)
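# `add_trt_llm_weight` registers a converted tensor under its TRT-LLM name,
# optionally casting it to the target dtype given as a string (e.g. "float16").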
def add_trt_llm_weight(weights: Flattened,
name: str,
param: np.ndarray,
dtype: Optional[str] = None):
assert name not in weights, f"{name} is already added."
param = numpy_to_torch(param)
if dtype is not None:
assert isinstance(dtype, str), f"dtype must be str, but got type {type(dtype)}"
param = param.to(str_dtype_to_torch(dtype))
weights[name] = param.contiguous()
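# `quantize` wraps TRT-LLM's symmetric per-channel weight-only quantization op
# and returns the quantized weights (int8, or int4 packed as quint4x2) together
# with their per-channel scales.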
def quantize(param: np.ndarray,
quant_mode: tensorrt_llm.quantization.QuantMode):
if quant_mode.is_int8_weight_only():
quant_dtype = torch.int8
elif quant_mode.is_int4_weight_only():
quant_dtype = torch.quint4x2
else:
raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")
param = numpy_to_torch(param)
param = param.t().contiguous()
# previously this fn was available in torch.ops.fastertransformer namespace
(
quantized_weights,
scales,
) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
param, quant_dtype)
if scales.dtype == torch.bfloat16:
scales = scales.to(torch.float32).numpy().astype("bfloat16")
else:
scales = scales.numpy()
return quantized_weights.numpy(), scales
Parsers = Union[JAXParser, KerasParser, TorchParser, HfParser]
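# Main conversion loop. For every checkpoint tensor it: renames the parameter
# to the TRT-LLM layout, fuses Q/K/V projections into a single qkv weight where
# the source keeps them separate, applies the tensor-parallel split for the
# current rank, optionally performs weight-only quantization, and adds 1.0 to
# the RMSNorm scales (Gemma stores them zero-centered). The vocabulary
# embedding is additionally cloned into `lm_head.weight`.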
def load_gemma_weights(
*,
trt_llm_config: tllm_utils.PretrainedConfig,
parameters_or_model_dir: "Union[Weights, Path]",
ckpt_parser: Parsers,
load_model_on_cpu: bool = True,
):
print("Loading weights...")
tik = time.time()
tp_rank = trt_llm_config.mapping.tp_rank
tp_size = trt_llm_config.mapping.tp_size
hidden_size = trt_llm_config.hidden_size
head_dim = trt_llm_config.head_size
weights = {}
if isinstance(parameters_or_model_dir, Path):
logger.debug(f"Loading directory {str(parameters_or_model_dir)}...")
model_params = ckpt_parser.load_parameters(
parameters_or_model_dir,
load_model_on_cpu=load_model_on_cpu,
)
else:
model_params = parameters_or_model_dir
model_params = ckpt_parser.flatten_params(model_params)
for name, param in model_params.items():
logger.debug(f"Converting weight {name}...")
trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
if trt_llm_name is None:  # skip; this weight is merged into another parameter or unused
continue
if "attn.q_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert gqa_mode
# initial shape: (num_q_heads, hidden_size, head_dim)
q_param = param.transpose(1, 0, 2)
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
kv_name = name.replace("q_einsum", "kv_einsum")
kv_param = model_params[kv_name]
kv_param = kv_param.reshape(
trt_llm_config.num_key_value_heads * 2,
hidden_size,
head_dim,
).transpose(1, 0, 2)
# -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
qkv_param = np.concatenate([q_param, kv_param], axis=1)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
qkv_param = qkv_param.transpose(1, 0)
# If int8 kv enabled, weight-only quantization will be done later.
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "self_attn.qkv_proj" in name:
q_param, k_param, v_param = np.split(param, [
trt_llm_config.num_attention_heads * trt_llm_config.head_size,
trt_llm_config.num_attention_heads * trt_llm_config.head_size +
trt_llm_config.num_key_value_heads * trt_llm_config.head_size
],
axis=0)
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
if not gqa_mode:
k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)
v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)
qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_per_group_scaling():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attn.qkv_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert not gqa_mode
# initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
qkv_param = param.transpose(0, 1, 3, 2)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
qkv_param.shape[3])
qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() \
and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attention/query_dense" in name:
# Keras specific KQV convert
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
if gqa_mode:
# initial shape: (num_q_heads, hidden_size, head_dim)
q_param = param.transpose(1, 0, 2)
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)
# initial shape: (2, num_kv_heads, hidden_size, head_dim)
k_name = name.replace("query", "key")
k_param = model_params[k_name]
v_name = name.replace("query", "value")
v_param = model_params[v_name]
kv_param = np.stack((k_param, v_param), axis=0)
kv_param = kv_param.reshape(
trt_llm_config.num_key_value_heads * 2,
hidden_size,
head_dim,
).transpose(1, 0, 2)
# -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
qkv_param = np.concatenate([q_param, kv_param], axis=1)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
qkv_param = qkv_param.transpose(1, 0)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
else:
q_param = param
k_name = name.replace("query", "key")
k_param = model_params[k_name]
v_name = name.replace("query", "value")
v_param = model_params[v_name]
# initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
qkv_param = np.stack((q_param, k_param, v_param), axis=0)
qkv_param = qkv_param.transpose(0, 1, 3, 2)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
qkv_param.shape[3])
qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name,
qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "q_proj" in name:
mqa_mode = trt_llm_config.num_key_value_heads == 1
if mqa_mode:
# initial shape: (num_heads * head_dim, hidden_size)
q_param = param
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
k_name = name.replace("q_proj", "k_proj")
k_param = model_params[k_name]
v_name = name.replace("q_proj", "v_proj")
v_param = model_params[v_name]
else:
# initial shape: (num_heads * head_dim, hidden_size)
q_param = param
q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
k_name = name.replace("q_proj", "k_proj")
k_param = model_params[k_name]
k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)
v_name = name.replace("q_proj", "v_proj")
v_param = model_params[v_name]
v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)
qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
# If int8 kv enabled, weight-only quantization will be done later.
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
qkv_param_quantized, qkv_param_scales = quantize(
qkv_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
qkv_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, qkv_param,
trt_llm_config.dtype)
elif "attention.dense.weight" in trt_llm_name:
# initial shape: (num_heads, head_dim, hidden_size)
if len(param.shape) == 3:
param = param.reshape(-1, param.shape[2])
param = param.transpose(
1, 0)  # (hidden_size, num_heads * head_dim)
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
if trt_llm_config.quant_mode.is_weight_only(
) and not trt_llm_config.quant_mode.has_int8_kv_cache():
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "mlp.fc.weight" in trt_llm_name:
if isinstance(ckpt_parser, KerasParser):
# initial shape: (hidden_size, intermediate_size)
fc_param, gate_param = param, model_params[name.replace(
"gating_ffw", "gating_ffw_2")]
elif isinstance(ckpt_parser, TorchParser):
# initial shape: (intermediate_size, hidden_size)
fc_param, gate_param = param, model_params[name.replace(
"mlp.gate_proj", "mlp.up_proj")]
fc_param = fc_param.transpose(1, 0)
gate_param = gate_param.transpose(1, 0)
elif isinstance(ckpt_parser, HfParser):
# initial shape: (intermediate_size, hidden_size)
fc_param, gate_param = param, model_params[name.replace(
"mlp.gate_proj", "mlp.up_proj")]
fc_param = fc_param.transpose(1, 0)
gate_param = gate_param.transpose(1, 0)
else:
# initial shape: (2, hidden_size, intermediate_size)
fc_param, gate_param = param[0], param[1]
fc_param = fc_param.transpose(1, 0)
fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
fc_param_quantized, fc_param_scales = quantize(
fc_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, fc_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
fc_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, fc_param,
trt_llm_config.dtype)
gate_param = gate_param.transpose(1, 0)
gate_param = split_matrix_tp(gate_param, tp_size, tp_rank, dim=0)
trt_llm_name = trt_llm_name.replace("mlp.fc.weight",
"mlp.gate.weight")
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
gate_param_quantized, gate_param_scales = quantize(
gate_param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, gate_param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
gate_param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, gate_param,
trt_llm_config.dtype)
elif "mlp.proj.weight" in trt_llm_name:
if not isinstance(ckpt_parser, TorchParser) and not isinstance(
ckpt_parser, HfParser):
# initial shape: (intermediate_size, hidden_size)
param = param.transpose(1, 0)
param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache():
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_channel_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name \
or "embed_tokens.weight" in name:
# TODO: safetensor doesn't allow to save a shared tensor.
# Currently we clone the weight; to save disk space, it would be
# better to skip saving the lm_head weights and share the embedding
# at load time.
lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
add_trt_llm_weight(weights, "lm_head.weight", np.copy(lm_head),
trt_llm_config.dtype)
if trt_llm_config.use_parallel_embedding:
assert trt_llm_config.vocab_size % tp_size == 0
param = split_matrix_tp(
param,
tp_size,
tp_rank,
dim=trt_llm_config.embedding_sharding_dim,
)
if trt_llm_config.quant_mode.is_int8_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
not trt_llm_config.quant_mode.has_int8_kv_cache() and trt_llm_config.quantization.exclude_modules is not None:
# shape of embedding table: [V, K], V: vocab size, K: embedding dim
# quantize will do following work:
# 1. transpose the input from [V, K] to [K, V]
# 2. compute V scales across dimension K
# 3. quantize the input to int8_weight
# 4. transpose the int8 weights back to [V, K]; however, 'quantize' reports the shape as [K, V], hence the reshape below
param_quantized, param_scales = quantize(
param, trt_llm_config.quant_mode)
# Reshape the [K, V] to [V, K] to match the real data layout.
param_quantized = np.ascontiguousarray(
param_quantized.reshape(
[param_quantized.shape[1], param_quantized.shape[0]]))
add_trt_llm_weight(weights, trt_llm_name, param_quantized)
add_trt_llm_weight(
weights,
trt_llm_name.replace(".weight", ".per_token_scale"),
param_scales,
trt_llm_config.dtype,
)
else:
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
elif any(keyword in name for keyword in (
"pre_attention_norm.scale",
"pre_ffw_norm.scale",
"final_norm.scale",
"pre_attention_norm/vars/0",
"pre_ffw_norm/vars/0",
"rms_normalization/vars/0",
"input_layernorm",
"post_attention_layernorm",
"pre_feedforward_layernorm",
"post_feedforward_layernorm",
"model.norm.weight",
)):
param = param + 1.0  # Gemma stores RMSNorm scales zero-centered; add 1.0 to recover the full scale (bfloat16 is upcast to float32 here)
add_trt_llm_weight(weights, trt_llm_name, param,
trt_llm_config.dtype)
else:
raise RuntimeError(f"Unhandled {name} module weights")
del model_params
print(
f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
)
return weights
@dataclass(frozen=True, kw_only=True)
class QuantizeModifiers:
"""
Bag of additional conversion parameters, sourced from argparse or from defaults.
We may use this class to easily create wrappers of `convert_gemma`
(Otherwise, we would need to repeat these params many times or spread `**kwargs` in an untypesafe manner)
"""
use_weight_only_with_precision: Optional[str] = None
per_channel: bool = False
per_token: bool = False
calib_dataset: str = "ccdv/cnn_dailymail"
tokenizer_dir: Optional[str] = None
enable_fp8: bool = False
fp8_kv_cache: bool = False
quant_ckpt_path: Optional[str] = None
@classmethod
def from_args(cls, args: Namespace) -> "QuantizeModifiers":
return cls(**{
k: v
for k, v in vars(args).items() if k in cls.__annotations__
})
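# Post-processing for quantization modes that do not go through ModelOpt:
# SmoothQuant and int8 KV cache (calibration is run on an HF-style model built
# from the converted weights), dummy-weight AWQ, and FP8 scale loading from a
# quantized checkpoint.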
def non_modelopt_quantize_if_needed(
weights: Weights, *, model_dir: Path,
quantize_modifiers: QuantizeModifiers,
trt_llm_config: Union[GemmaConfig,
tllm_utils.PretrainedConfig]) -> Weights:
tokenizer_dir = quantize_modifiers.tokenizer_dir or model_dir
quant_cfg = trt_llm_config.quantization
kv_cache_is_int8 = quant_cfg.kv_cache_quant_algo == QuantAlgo.INT8
use_smooth_quant = quant_cfg._use_plugin_sq
if use_smooth_quant or kv_cache_is_int8:
qkv_para = {}
smoother = {}
dataset = load_calib_dataset(quantize_modifiers.calib_dataset)
tokenizer = sp.SentencePieceProcessor(model_file=tokenizer_dir)
if "transformer.vocab_embedding.weight" in weights:
# To run SmoothQuant through the HF model, we need to pre-scale the embedding.
weights["transformer.vocab_embedding.weight"] = torch.multiply(
weights["transformer.vocab_embedding.weight"].to(torch.float32),
math.sqrt(trt_llm_config.hidden_size),
)
hf_model = create_model_from_config(trt_llm_config, weights)
act_range = capture_activation_range(hf_model, tokenizer, dataset)
if use_smooth_quant:
smooth_model(hf_model, act_range, quant_cfg.smoothquant_val,
qkv_para, smoother)
weights = convert_hf_model(
hf_model=hf_model,
mapping=trt_llm_config.mapping,
vocab_size=trt_llm_config.vocab_size or 32000,
dtype=trt_llm_config.dtype,
use_parallel_embedding=trt_llm_config.use_parallel_embedding,
sharding_dim=0,
use_weight_only=quantize_modifiers.use_weight_only_with_precision
is not None,
plugin_weight_only_quant_type=torch.int8
if quantize_modifiers.use_weight_only_with_precision == 'int8' else
torch.quint4x2,
use_smooth_quant=use_smooth_quant,
per_channel=quantize_modifiers.per_channel,
per_token=quantize_modifiers.per_token,
int8_kv_cache=kv_cache_is_int8,
act_range=act_range,
qkv_para=qkv_para,
smoother=smoother)
if "transformer.vocab_embedding.weight" in weights:
# Revert the scaling of embedding
weights["transformer.vocab_embedding.weight"] = torch.divide(
weights["transformer.vocab_embedding.weight"].to(torch.float32),
math.sqrt(trt_llm_config.hidden_size),
).to(str_dtype_to_torch(trt_llm_config.dtype))
if "lm_head.weight" in weights: # Remove lm_head before saving it in unified checkpoint.
del weights["lm_head.weight"]
return weights
if quantize_modifiers.use_weight_only_with_precision and quantize_modifiers.use_weight_only_with_precision.endswith(
"awq"):
weights = dummy_weights_awq(
weights=weights,
precision=quantize_modifiers.use_weight_only_with_precision,
trt_llm_config=trt_llm_config,
group_size=128)
elif quantize_modifiers.enable_fp8 or quantize_modifiers.fp8_kv_cache:
weight_scales = quantize_fp8_weights(weights,
trt_llm_config.num_hidden_layers,
trt_llm_config.mapping)
scales = load_from_fp8_gemma(quantize_modifiers.quant_ckpt_path,
trt_llm_config.num_hidden_layers,
trt_llm_config.mapping,
quantize_modifiers.fp8_kv_cache,
weight_scales)
weights.update(scales)
return weights
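# Convenience wrapper: converts the named parameters of an already-instantiated
# Hugging Face Gemma model using `HfParser`.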
def load_gemma_weights_from_hf_model(
hf_model: "transformers.AutoModelForCausalLM",
config: GemmaConfig) -> Weights:
return load_gemma_weights(parameters_or_model_dir=dict(
hf_model.named_parameters()),
trt_llm_config=config,
ckpt_parser=HfParser())
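# Example usage (a minimal sketch, kept as a comment so it is not executed on
# import): converting a Hugging Face Gemma checkpoint to TRT-LLM weights. The
# model id and the exact `GemmaConfig.from_hugging_face` call are illustrative
# assumptions, not a verified API.
#
#   from transformers import AutoModelForCausalLM
#   from tensorrt_llm.models.gemma.config import GemmaConfig
#
#   hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b",
#                                                   torch_dtype="auto")
#   config = GemmaConfig.from_hugging_face("google/gemma-2b", dtype="bfloat16")
#   weights = load_gemma_weights_from_hf_model(hf_model, config)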