# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
import re
import time
from argparse import Namespace
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Dict, Optional, Union

import h5py
import numpy
import numpy as np
import sentencepiece as sp
import torch

import tensorrt_llm
import tensorrt_llm.models.modeling_utils as tllm_utils
from tensorrt_llm._utils import (np_bfloat16, numpy_to_torch,
                                 str_dtype_to_torch, torch_to_numpy)
from tensorrt_llm.logger import logger
from tensorrt_llm.models.convert_utils import load_calib_dataset
from tensorrt_llm.models.gemma.config import GemmaConfig
from tensorrt_llm.models.gemma.smoothquant import (capture_activation_range,
                                                   convert_hf_model,
                                                   create_model_from_config,
                                                   smooth_model)
from tensorrt_llm.models.gemma.utils import params as gemma_params
from tensorrt_llm.models.gemma.weight import (dummy_weights_awq,
                                              load_from_fp8_gemma,
                                              quantize_fp8_weights)
from tensorrt_llm.quantization.mode import QuantAlgo

Weights = Dict[str, torch.Tensor]
Flattened = Dict[str, np.ndarray]

if TYPE_CHECKING:
    import transformers


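# The parser classes below (JAXParser, KerasParser, TorchParser, HfParser) all
# expose the same duck-typed interface: load_parameters, embedding_weights,
# get_config, rename_to_trt_llm and flatten_params. `load_gemma_weights` relies
# only on that interface, so the checkpoint format is selected simply by
# passing a different parser instance.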
class JAXParser:

    def load_parameters(self,
                        checkpoint_path: Path,
                        load_model_on_cpu: bool = False) -> Weights:
        checkpoint_path = checkpoint_path.absolute()
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        return gemma_params.nest_params(
            gemma_params.param_remapper(
                gemma_params.load_params(checkpoint_path)))

    def embedding_weights(self, ckpt_params) -> np.ndarray:
        return ckpt_params["transformer"]["embedder"]["input_embedding"]

    def get_config(self, checkpoint_path, ckpt_params,
                   num_embed) -> SimpleNamespace:
        from tensorrt_llm.models.gemma.utils.transformer import \
            TransformerConfig

        return TransformerConfig.from_params(ckpt_params, num_embed=num_embed)

    def rename_to_trt_llm(self, name: str) -> Optional[str]:
        """Rename a gemma parameter name to the corresponding TRT-LLM style name."""
        prefix, name = name.split(".", maxsplit=1)
        assert prefix == "transformer"
        sub_patterns = (
            (r"embedder.input_embedding", r"vocab_embedding.weight"),
            (r"layer_(\d+).pre_attention_norm.scale",
             r"layers.\1.input_layernorm.weight"),
            (r"layer_(\d+).attn.q_einsum.w", r"layers.\1.attention.qkv.weight"),
            (r"layer_(\d+).attn.kv_einsum.w",
             None),  # drop as kv will be concatenated with q
            (r"layer_(\d+).attn.qkv_einsum.w",
             r"layers.\1.attention.qkv.weight"),
            (r"layer_(\d+).attn.attn_vec_einsum.w",
             r"layers.\1.attention.dense.weight"),
            (r"layer_(\d+).mlp.gating_einsum", r"layers.\1.mlp.fc.weight"),
            (r"layer_(\d+).mlp.linear", r"layers.\1.mlp.proj.weight"),
            (r"layer_(\d+).pre_ffw_norm.scale",
             r"layers.\1.post_layernorm.weight"),
            (r"final_norm.scale", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {prefix}.{name}")

    def flatten_params(self, params) -> Flattened:
        import flax.traverse_util

        new_params = flax.traverse_util.flatten_dict(params, sep=".")
        # if the dtype is bfloat16, cast to float32
        for k in new_params:
            if new_params[k].dtype != np.float32 and new_params[
                    k].dtype != np.float16:
                new_params[k] = new_params[k].astype(np.float32)
        return new_params


class KerasParser:

    def load_parameters(self,
                        checkpoint_path: Path,
                        load_model_on_cpu: bool = False) -> h5py.File:
        checkpoint_path = checkpoint_path.absolute()
        config_file = "config.json"
        weights_file = json.load(open(checkpoint_path / config_file))["weights"]
        h5_path = checkpoint_path / weights_file
        return h5py.File(h5_path, "r")

    def embedding_weights(self, ckpt_params) -> np.ndarray:
        return np.array(ckpt_params["layers/reversible_embedding/vars/0"])

    def get_config(self, checkpoint_path, ckpt_params,
                   num_embed) -> SimpleNamespace:
        checkpoint_path = checkpoint_path.absolute()
        config_file = "config.json"
        config_old = json.load(open(checkpoint_path / config_file))["config"]
        config_new = SimpleNamespace()
        config_new.num_layers = config_old["num_layers"]
        config_new.num_embed = config_old["vocabulary_size"]
        config_new.embed_dim = config_old["hidden_dim"]
        config_new.hidden_dim = config_old["intermediate_dim"] // 2
        config_new.num_heads = config_old["num_query_heads"]
        config_new.head_dim = config_old["head_dim"]
        config_new.num_kv_heads = config_old["num_key_value_heads"]
        return config_new

    def rename_to_trt_llm(self, name: str) -> Optional[str]:
        """Rename a gemma parameter name to the corresponding TRT-LLM style name."""
        prefix = "transformer"
        name = name.replace("/gemma_decoder_block/", "/gemma_decoder_block_0/")
        sub_patterns = (
            (r"layers/reversible_embedding/vars/0", r"vocab_embedding.weight"),
            (r"layers/gemma_decoder_block_(\d+)/pre_attention_norm/vars/0",
             r"layers.\1.input_layernorm.weight"),
            (r"layers/gemma_decoder_block_(\d+)/attention/query_dense/vars/0",
             r"layers.\1.attention.qkv.weight"),
            (r"layers/gemma_decoder_block_(\d+)/attention/key_dense/vars/0",
             None),  # drop as k will be concatenated with q
            (r"layers/gemma_decoder_block_(\d+)/attention/value_dense/vars/0",
             None),  # drop as v will be concatenated with q
            (r"layers/gemma_decoder_block_(\d+)/attention/output_dense/vars/0",
             r"layers.\1.attention.dense.weight"),
            (r"layers/gemma_decoder_block_(\d+)/gating_ffw/vars/0",
             r"layers.\1.mlp.fc.weight"),
            (r"layers/gemma_decoder_block_(\d+)/gating_ffw_2/vars/0",
             None),  # merged with above
            (r"layers/gemma_decoder_block_(\d+)/ffw_linear/vars/0",
             r"layers.\1.mlp.proj.weight"),
            (r"layers/gemma_decoder_block_(\d+)/pre_ffw_norm/vars/0",
             r"layers.\1.post_layernorm.weight"),
            (r"layers/rms_normalization/vars/0", r"ln_f.weight"),
            (r"optimizer/vars/(\d+)", None),  # Not used
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {prefix}.{name}")

    def flatten_params(self, params) -> Flattened:
        f_params = {}

        def walk(name, obj):
            if isinstance(obj, h5py.Dataset):
                if obj.dtype == "|V2":
                    # bfloat16 case
                    f_params[name] = torch_to_numpy(
                        numpy_to_torch(np.array(obj).astype(np_bfloat16)).to(
                            torch.float32))
                else:
                    f_params[name] = np.array(obj)

        params.visititems(walk)
        return f_params


class TorchParser:

    def load_parameters(self,
                        checkpoint_path: Path,
                        load_model_on_cpu: bool = False) -> Weights:
        ckpt_path = list(checkpoint_path.glob('*.ckpt'))[0]
        model_params = torch.load(ckpt_path)['model_state_dict']
        model_params.pop('freqs_cis')
        return model_params

    def embedding_weights(self, ckpt_params) -> np.ndarray:
        return ckpt_params["embedder.weight"]

    def get_config(self, checkpoint_path, ckpt_params,
                   num_embed) -> SimpleNamespace:
        checkpoint_path = checkpoint_path.absolute()
        config_file = "config.json"
        with open(checkpoint_path / config_file, 'r') as f:
            json_str = f.read()
            json_str = json_str.replace("'", "\"")
            json_str = json_str.replace(",\n}", "\n}")
            config_old = json.loads(json_str)
        config_new = SimpleNamespace()
        config_new.num_layers = config_old["num_hidden_layers"]
        config_new.num_embed = config_old["vocab_size"]
        config_new.embed_dim = config_old["hidden_size"]
        config_new.hidden_dim = config_old["intermediate_size"]
        config_new.num_heads = config_old["num_attention_heads"]
        config_new.head_dim = config_old["head_dim"]
        config_new.num_kv_heads = config_old["num_key_value_heads"]
        return config_new

    def rename_to_trt_llm(self, name: str) -> Optional[str]:
        """Rename a gemma parameter name to the corresponding TRT-LLM style name."""
        prefix = "transformer"
        sub_patterns = (
            (r"embedder.weight", r"vocab_embedding.weight"),
            (r"model.layers.(\d+).input_layernorm.weight",
             r"layers.\1.input_layernorm.weight"),
            (r"model.layers.(\d+).self_attn.qkv_proj.weight",
             r"layers.\1.attention.qkv.weight"),
            (r"model.layers.(\d+).self_attn.o_proj.weight",
             r"layers.\1.attention.dense.weight"),
            (r"model.layers.(\d+).mlp.gate_proj.weight",
             r"layers.\1.mlp.fc.weight"),
            (r"model.layers.(\d+).mlp.up_proj.weight",
             None),  # merged with above
            (r"model.layers.(\d+).mlp.down_proj.weight",
             r"layers.\1.mlp.proj.weight"),
            (r"model.layers.(\d+).post_attention_layernorm.weight",
             r"layers.\1.post_layernorm.weight"),
            (r"model.norm.weight", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {name}")

    def flatten_params(self, params) -> Flattened:
        f_params = {}
        for k, v in params.items():
            if v.dtype == torch.bfloat16:
                v = v.float()
            f_params[k] = torch_to_numpy(v)
        return f_params


class HfParser:

    def load_parameters(self,
                        checkpoint_path: Path,
                        load_model_on_cpu: bool = False) -> Weights:
        """`AutoModelForCausalLM.from_pretrained` will parse the correct gemma, whether Gemma or Gemma2 or future versions."""
        from transformers import AutoModelForCausalLM
        hf_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map="cpu" if load_model_on_cpu else "auto",
            torch_dtype='auto',
            trust_remote_code=True,
        )
        model_params = dict(hf_model.named_parameters())
        return model_params

    def embedding_weights(self, ckpt_params) -> np.ndarray:
        raise RuntimeError(
            "This method shouldn't be called - `GemmaConfig.from_hugging_face` takes care of this."
        )

    def get_config(self, checkpoint_path, ckpt_params,
                   num_embed) -> SimpleNamespace:
        raise RuntimeError(
            "This method shouldn't be called - `GemmaConfig.from_hugging_face` takes care of this."
        )

    def rename_to_trt_llm(self, name: str) -> Optional[str]:
        """Rename a gemma parameter name to the corresponding TRT-LLM style name."""
        prefix = "transformer"
        sub_patterns = (
            (r"model.embed_tokens.weight", r"vocab_embedding.weight"),
            (r"model.layers.(\d+).input_layernorm.weight",
             r"layers.\1.input_layernorm.weight"),
            (r"model.layers.(\d+).self_attn.q_proj.weight",
             r"layers.\1.attention.qkv.weight"),
            (r"model.layers.(\d+).self_attn.k_proj.weight",
             None),  # merged with above
            (r"model.layers.(\d+).self_attn.v_proj.weight",
             None),  # merged with above
            (r"model.layers.(\d+).self_attn.o_proj.weight",
             r"layers.\1.attention.dense.weight"),
            (r"model.layers.(\d+).mlp.gate_proj.weight",
             r"layers.\1.mlp.fc.weight"),
            (r"model.layers.(\d+).mlp.up_proj.weight",
             None),  # merged with above
            (r"model.layers.(\d+).mlp.down_proj.weight",
             r"layers.\1.mlp.proj.weight"),
            (r"model.layers.(\d+).post_attention_layernorm.weight",
             r"layers.\1.post_layernorm.weight"),
            (r"model.layers.(\d+).pre_feedforward_layernorm.weight",
             r"layers.\1.pre_feedforward_layernorm.weight"),
            (r"model.layers.(\d+).post_feedforward_layernorm.weight",
             r"layers.\1.post_feedforward_layernorm.weight"),
            (r"model.norm.weight", r"ln_f.weight"),
        )

        for source, target in sub_patterns:
            if re.match(source, name):
                if target is None:
                    return target
                else:
                    name = re.sub(source, target, name)
                    return ".".join((prefix, name))
        else:
            raise ValueError(f"Don't know how to rename {prefix}.{name}")

    def flatten_params(self, params: Weights) -> Dict[str, numpy.ndarray]:
        f_params = {}
        for k, v in params.items():
            if v.dtype == torch.bfloat16:
                v = v.float()
            f_params[k] = torch_to_numpy(v)
        return f_params


def split(v: np.ndarray, tp_size: int, idx: int, dim: int = 0) -> np.ndarray:
    if tp_size == 1:
        return v
    return np.split(v, tp_size, axis=dim)[idx]


def split_matrix_tp(v: np.ndarray, tensor_parallel: int, rank: int,
                    dim: int) -> np.ndarray:
    return split(v, tensor_parallel, rank, dim=dim)


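# Example for the splitting helpers above (a minimal sketch; the 4x8 array and
# tp_size=4 are illustrative only): splitting a weight row-wise across 4 ranks
# gives each rank a (1, 8) shard, e.g.
#
#   w = np.arange(32, dtype=np.float32).reshape(4, 8)
#   shard = split_matrix_tp(w, tensor_parallel=4, rank=1, dim=0)
#   assert shard.shape == (1, 8)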
def add_trt_llm_weight(weights: Flattened,
                       name: str,
                       param: np.ndarray,
                       dtype: Optional[np.dtype] = None):
    assert name not in weights, f"{name} is already added."
    param = numpy_to_torch(param)
    if dtype is not None:
        assert isinstance(dtype,
                          str), f"dtype must be str, but get type {type(dtype)}"
        param = param.to(str_dtype_to_torch(dtype))
    weights[name] = param.contiguous()


def quantize(param: np.ndarray,
             quant_mode: tensorrt_llm.quantization.QuantMode):
    if quant_mode.is_int8_weight_only():
        quant_dtype = torch.int8
    elif quant_mode.is_int4_weight_only():
        quant_dtype = torch.quint4x2
    else:
        raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")

    param = numpy_to_torch(param)
    param = param.t().contiguous()

    # previously this fn was available in torch.ops.fastertransformer namespace
    (
        quantized_weights,
        scales,
    ) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
        param, quant_dtype)

    if scales.dtype == torch.bfloat16:
        scales = scales.to(torch.float32).numpy().astype("bfloat16")
    else:
        scales = scales.numpy()
    return quantized_weights.numpy(), scales


Parsers = Union[JAXParser, KerasParser, TorchParser, HfParser]


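# Minimal usage sketch (illustrative only; the checkpoint path is a
# placeholder): every parser is driven the same way by `load_gemma_weights`,
# e.g.
#
#   weights = load_gemma_weights(
#       trt_llm_config=trt_llm_config,  # a GemmaConfig/PretrainedConfig
#       parameters_or_model_dir=Path("/path/to/checkpoint"),
#       ckpt_parser=TorchParser(),
#   )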
def load_gemma_weights(
    *,
    trt_llm_config: tllm_utils.PretrainedConfig,
    parameters_or_model_dir: "Union[Weights, Path]",
    ckpt_parser: Parsers,
    load_model_on_cpu: bool = True,
):
    print("Loading weights...")
    tik = time.time()

    tp_rank = trt_llm_config.mapping.tp_rank
    tp_size = trt_llm_config.mapping.tp_size
    hidden_size = trt_llm_config.hidden_size
    head_dim = trt_llm_config.head_size

    weights = {}

    if isinstance(parameters_or_model_dir, Path):
        logger.debug(f"Loading directory {str(parameters_or_model_dir)}...")
        model_params = ckpt_parser.load_parameters(
            parameters_or_model_dir,
            load_model_on_cpu=load_model_on_cpu,
        )
    else:
        model_params = parameters_or_model_dir

    model_params = ckpt_parser.flatten_params(model_params)

    for name, param in model_params.items():
        logger.debug(f"Converting weight {name}...")
        trt_llm_name = ckpt_parser.rename_to_trt_llm(name)
        if trt_llm_name is None:  # omit as used with other params
            continue

        if "attn.q_einsum" in name:
            gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
            assert gqa_mode

            # initial shape: (num_q_heads, hidden_size, head_dim)
            q_param = param.transpose(1, 0, 2)
            q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)

            # initial shape: (2, num_kv_heads, hidden_size, head_dim)
            kv_name = name.replace("q_einsum", "kv_einsum")
            kv_param = model_params[kv_name]
            kv_param = kv_param.reshape(
                trt_llm_config.num_key_value_heads * 2,
                hidden_size,
                head_dim,
            ).transpose(1, 0, 2)

            # -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
            qkv_param = np.concatenate([q_param, kv_param], axis=1)
            qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
            qkv_param = qkv_param.transpose(1, 0)

            # If int8 kv enabled, weight-only quantization will be done later.
            if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                    not trt_llm_config.quant_mode.has_int8_kv_cache():
                qkv_param_quantized, qkv_param_scales = quantize(
                    qkv_param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    qkv_param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
        elif "self_attn.qkv_proj" in name:
            q_param, k_param, v_param = np.split(param, [
                trt_llm_config.num_attention_heads * trt_llm_config.head_size,
                trt_llm_config.num_attention_heads * trt_llm_config.head_size +
                trt_llm_config.num_key_value_heads * trt_llm_config.head_size
            ],
                                                 axis=0)
            gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads

            q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)
            if not gqa_mode:
                k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)
                v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)

            qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
            if trt_llm_config.quant_mode.is_weight_only(
            ) and not trt_llm_config.quant_mode.has_per_group_scaling():
                qkv_param_quantized, qkv_param_scales = quantize(
                    qkv_param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    qkv_param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
        elif "attn.qkv_einsum" in name:
            gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
            assert not gqa_mode
            # initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
            qkv_param = param.transpose(0, 1, 3, 2)
            qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
                                          qkv_param.shape[3])
            qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
            qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
            if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() \
                    and not trt_llm_config.quant_mode.has_int8_kv_cache():
                qkv_param_quantized, qkv_param_scales = quantize(
                    qkv_param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    qkv_param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
        elif "attention/query_dense" in name:
            # Keras specific KQV convert
            gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
            if gqa_mode:

                # initial shape: (num_q_heads, hidden_size, head_dim)
                q_param = param.transpose(1, 0, 2)
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1)

                # initial shape: (2, num_kv_heads, hidden_size, head_dim)
                k_name = name.replace("query", "key")
                k_param = model_params[k_name]
                v_name = name.replace("query", "value")
                v_param = model_params[v_name]
                kv_param = np.stack((k_param, v_param), axis=0)

                kv_param = kv_param.reshape(
                    trt_llm_config.num_key_value_heads * 2,
                    hidden_size,
                    head_dim,
                ).transpose(1, 0, 2)

                # -> (hidden_size, num_q_heads / tp_size + 2, head_dim)
                qkv_param = np.concatenate([q_param, kv_param], axis=1)
                qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)
                qkv_param = qkv_param.transpose(1, 0)

                if trt_llm_config.quant_mode.is_weight_only(
                ) and not trt_llm_config.quant_mode.has_int8_kv_cache():
                    qkv_param_quantized, qkv_param_scales = quantize(
                        qkv_param, trt_llm_config.quant_mode)
                    add_trt_llm_weight(weights, trt_llm_name,
                                       qkv_param_quantized)
                    add_trt_llm_weight(
                        weights,
                        trt_llm_name.replace(".weight", ".per_channel_scale"),
                        qkv_param_scales,
                        trt_llm_config.dtype,
                    )
                else:
                    add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                       trt_llm_config.dtype)
            else:
                q_param = param
                k_name = name.replace("query", "key")
                k_param = model_params[k_name]
                v_name = name.replace("query", "value")
                v_param = model_params[v_name]
                # initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size]
                qkv_param = np.stack((q_param, k_param, v_param), axis=0)
                qkv_param = qkv_param.transpose(0, 1, 3, 2)
                qkv_param = qkv_param.reshape(qkv_param.shape[0], -1,
                                              qkv_param.shape[3])
                qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1)
                qkv_param = qkv_param.reshape(-1, qkv_param.shape[2])
                if trt_llm_config.quant_mode.is_weight_only(
                ) and not trt_llm_config.quant_mode.has_int8_kv_cache():
                    qkv_param_quantized, qkv_param_scales = quantize(
                        qkv_param, trt_llm_config.quant_mode)
                    add_trt_llm_weight(weights, trt_llm_name,
                                       qkv_param_quantized)
                    add_trt_llm_weight(
                        weights,
                        trt_llm_name.replace(".weight", ".per_channel_scale"),
                        qkv_param_scales,
                        trt_llm_config.dtype,
                    )
                else:
                    add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                       trt_llm_config.dtype)
        elif "q_proj" in name:
            mqa_mode = trt_llm_config.num_key_value_heads == 1

            if mqa_mode:
                # initial shape: (num_heads * head_dim, hidden_size)
                q_param = param
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)

                k_name = name.replace("q_proj", "k_proj")
                k_param = model_params[k_name]

                v_name = name.replace("q_proj", "v_proj")
                v_param = model_params[v_name]
            else:
                # initial shape: (num_heads * head_dim, hidden_size)
                q_param = param
                q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0)

                k_name = name.replace("q_proj", "k_proj")
                k_param = model_params[k_name]
                k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0)

                v_name = name.replace("q_proj", "v_proj")
                v_param = model_params[v_name]
                v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0)

            qkv_param = np.concatenate([q_param, k_param, v_param], axis=0)
            qkv_param = qkv_param.reshape(qkv_param.shape[0], -1)

            # If int8 kv enabled, weight-only quantization will be done later.
            if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                    not trt_llm_config.quant_mode.has_int8_kv_cache():
                qkv_param_quantized, qkv_param_scales = quantize(
                    qkv_param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, qkv_param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    qkv_param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, qkv_param,
                                   trt_llm_config.dtype)
        elif "attention.dense.weight" in trt_llm_name:
            # initial shape: (num_heads, head_dim, hidden_size)
            if len(param.shape) == 3:
                param = param.reshape(-1, param.shape[2])
                param = param.transpose(
                    1, 0)  # (hidden_size, num_heads * head_dim)
            param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
            if trt_llm_config.quant_mode.is_weight_only(
            ) and not trt_llm_config.quant_mode.has_int8_kv_cache():
                param_quantized, param_scales = quantize(
                    param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
        elif "mlp.fc.weight" in trt_llm_name:
            if isinstance(ckpt_parser, KerasParser):
                # initial shape: (hidden_size, intermediate_size)
                fc_param, gate_param = param, model_params[name.replace(
                    "gating_ffw", "gating_ffw_2")]
            elif isinstance(ckpt_parser, TorchParser):
                # initial shape: (intermediate_size, hidden_size)
                fc_param, gate_param = param, model_params[name.replace(
                    "mlp.gate_proj", "mlp.up_proj")]
                fc_param = fc_param.transpose(1, 0)
                gate_param = gate_param.transpose(1, 0)
            elif isinstance(ckpt_parser, HfParser):
                # initial shape: (intermediate_size, hidden_size)
                fc_param, gate_param = param, model_params[name.replace(
                    "mlp.gate_proj", "mlp.up_proj")]
                fc_param = fc_param.transpose(1, 0)
                gate_param = gate_param.transpose(1, 0)
            else:
                # initial shape: (2, hidden_size, intermediate_size)
                fc_param, gate_param = param[0], param[1]
            fc_param = fc_param.transpose(1, 0)
            fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0)
            if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                    not trt_llm_config.quant_mode.has_int8_kv_cache():
                fc_param_quantized, fc_param_scales = quantize(
                    fc_param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, fc_param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    fc_param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, fc_param,
                                   trt_llm_config.dtype)

            gate_param = gate_param.transpose(1, 0)
            gate_param = split_matrix_tp(gate_param, tp_size, tp_rank, dim=0)
            trt_llm_name = trt_llm_name.replace("mlp.fc.weight",
                                                "mlp.gate.weight")
            if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                    not trt_llm_config.quant_mode.has_int8_kv_cache():
                gate_param_quantized, gate_param_scales = quantize(
                    gate_param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, gate_param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    gate_param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, gate_param,
                                   trt_llm_config.dtype)
        elif "mlp.proj.weight" in trt_llm_name:
            if not isinstance(ckpt_parser, TorchParser) and not isinstance(
                    ckpt_parser, HfParser):
                # initial shape: (intermediate_size, hidden_size)
                param = param.transpose(1, 0)
            param = split_matrix_tp(param, tp_size, tp_rank, dim=1)
            if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                    not trt_llm_config.quant_mode.has_int8_kv_cache():
                param_quantized, param_scales = quantize(
                    param, trt_llm_config.quant_mode)
                add_trt_llm_weight(weights, trt_llm_name, param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_channel_scale"),
                    param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
        elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name \
                or "embed_tokens.weight" in name:
            # TODO: safetensor doesn't allow to save a shared tensor.
            # Currently, we clone the weight but to save the disk, it
            # would be better to skip saving lm_head weights and
            # handle it at the loading phase.
            lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0)
            add_trt_llm_weight(weights, "lm_head.weight", np.copy(lm_head),
                               trt_llm_config.dtype)

            if trt_llm_config.use_parallel_embedding:
                assert trt_llm_config.vocab_size % tp_size == 0
                param = split_matrix_tp(
                    param,
                    tp_size,
                    tp_rank,
                    dim=trt_llm_config.embedding_sharding_dim,
                )
            if trt_llm_config.quant_mode.is_int8_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \
                    not trt_llm_config.quant_mode.has_int8_kv_cache() and trt_llm_config.quantization.exclude_modules is not None:

                # shape of embedding table: [V, K], V: vocab size, K: embedding dim

                # quantize will do following work:
                # 1. transpose the input from [V, K] to [K, V]
                # 2. compute V scales across dimension K
                # 3. quantize the input to int8_weight
                # 4. transpose the int8_weight to [V, K], but there is a bug in 'quantize' that the claimed shape is [K, V]
                param_quantized, param_scales = quantize(
                    param, trt_llm_config.quant_mode)

                # Reshape the [K, V] to [V, K] to match the real data layout.
                param_quantized = np.ascontiguousarray(
                    param_quantized.reshape(
                        [param_quantized.shape[1], param_quantized.shape[0]]))

                add_trt_llm_weight(weights, trt_llm_name, param_quantized)
                add_trt_llm_weight(
                    weights,
                    trt_llm_name.replace(".weight", ".per_token_scale"),
                    param_scales,
                    trt_llm_config.dtype,
                )
            else:
                add_trt_llm_weight(weights, trt_llm_name, param,
                                   trt_llm_config.dtype)
        elif any(keyword in name for keyword in (
                "pre_attention_norm.scale",
                "pre_ffw_norm.scale",
                "final_norm.scale",
                "pre_attention_norm/vars/0",
                "pre_ffw_norm/vars/0",
                "rms_normalization/vars/0",
                "input_layernorm",
                "post_attention_layernorm",
                "pre_feedforward_layernorm",
                "post_feedforward_layernorm",
                "model.norm.weight",
        )):
            param = param + 1.0  # upcasted to float32 in case of bfloat16
            add_trt_llm_weight(weights, trt_llm_name, param,
                               trt_llm_config.dtype)
        else:
            raise RuntimeError(f"Unhandled {name} module weights")
    del model_params

    print(
        f"Weights loaded. Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}"
    )
    return weights


@dataclass(frozen=True, kw_only=True)
class QuantizeModifiers:
    """
    Bag of additional conversion parameters, sourced from argparse or from defaults.
    We may use this class to easily create wrappers of `convert_gemma`
    (Otherwise, we would need to repeat these params many times or spread `**kwargs` in an untypesafe manner)
    """

    use_weight_only_with_precision: Optional[str] = None
    per_channel: bool = False
    per_token: bool = False
    calib_dataset: str = "ccdv/cnn_dailymail"
    tokenizer_dir: Optional[str] = None
    enable_fp8: bool = False
    fp8_kv_cache: bool = False
    quant_ckpt_path: Optional[str] = None

    @classmethod
    def from_args(cls, args: Namespace) -> "QuantizeModifiers":
        return cls(**{
            k: v
            for k, v in vars(args).items() if k in cls.__annotations__
        })


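# Sketch of how `QuantizeModifiers` is typically built (assuming `args` comes
# from an argparse-based conversion script that defines matching flags):
#
#   quantize_modifiers = QuantizeModifiers.from_args(args)  # keeps only known fields
#   # or with explicit overrides:
#   quantize_modifiers = QuantizeModifiers(enable_fp8=True, fp8_kv_cache=True)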
def non_modelopt_quantize_if_needed(
        weights: Weights, *, model_dir: Path,
        quantize_modifiers: QuantizeModifiers,
        trt_llm_config: Union[GemmaConfig,
                              tllm_utils.PretrainedConfig]) -> Weights:
    tokenizer_dir = quantize_modifiers.tokenizer_dir or model_dir

    quant_cfg = trt_llm_config.quantization

    kv_cache_is_int8 = quant_cfg.kv_cache_quant_algo == QuantAlgo.INT8
    use_smooth_quant = quant_cfg.use_plugin_sq

    if use_smooth_quant or kv_cache_is_int8:
        qkv_para = {}
        smoother = {}
        dataset = load_calib_dataset(quantize_modifiers.calib_dataset)
        tokenizer = sp.SentencePieceProcessor(model_file=tokenizer_dir)
        if "transformer.vocab_embedding.weight" in weights:
            # To use the HF to do SmoothQuant, we need to scale the embedding.
            weights["transformer.vocab_embedding.weight"] = torch.multiply(
                weights["transformer.vocab_embedding.weight"].to(torch.float32),
                math.sqrt(trt_llm_config.hidden_size),
            )
        hf_model = create_model_from_config(trt_llm_config, weights)
        act_range = capture_activation_range(hf_model, tokenizer, dataset)
        if use_smooth_quant:
            smooth_model(hf_model, act_range, quant_cfg.smoothquant_val,
                         qkv_para, smoother)
        weights = convert_hf_model(
            hf_model=hf_model,
            mapping=trt_llm_config.mapping,
            vocab_size=trt_llm_config.vocab_size or 32000,
            dtype=trt_llm_config.dtype,
            use_parallel_embedding=trt_llm_config.use_parallel_embedding,
            sharding_dim=0,
            use_weight_only=quantize_modifiers.use_weight_only_with_precision
            is not None,
            plugin_weight_only_quant_type=torch.int8
            if quantize_modifiers.use_weight_only_with_precision == 'int8' else
            torch.quint4x2,
            use_smooth_quant=use_smooth_quant,
            per_channel=quantize_modifiers.per_channel,
            per_token=quantize_modifiers.per_token,
            int8_kv_cache=kv_cache_is_int8,
            act_range=act_range,
            qkv_para=qkv_para,
            smoother=smoother)
        if "transformer.vocab_embedding.weight" in weights:
            # Revert the scaling of embedding
            weights["transformer.vocab_embedding.weight"] = torch.divide(
                weights["transformer.vocab_embedding.weight"].to(torch.float32),
                math.sqrt(trt_llm_config.hidden_size),
            ).to(str_dtype_to_torch(trt_llm_config.dtype))
        if "lm_head.weight" in weights:  # Remove lm_head before saving it in unified checkpoint.
            del weights["lm_head.weight"]
        return weights
    if quantize_modifiers.use_weight_only_with_precision and quantize_modifiers.use_weight_only_with_precision.endswith(
            "awq"):
        weights = dummy_weights_awq(
            weights=weights,
            precision=quantize_modifiers.use_weight_only_with_precision,
            trt_llm_config=trt_llm_config,
            group_size=128)
    elif quantize_modifiers.enable_fp8 or quantize_modifiers.fp8_kv_cache:
        weight_scales = quantize_fp8_weights(weights,
                                             trt_llm_config.num_hidden_layers,
                                             trt_llm_config.mapping)
        scales = load_from_fp8_gemma(quantize_modifiers.quant_ckpt_path,
                                     trt_llm_config.num_hidden_layers,
                                     trt_llm_config.mapping,
                                     quantize_modifiers.fp8_kv_cache,
                                     weight_scales)
        weights.update(scales)

    return weights


def load_gemma_weights_from_hf_model(
        hf_model: "transformers.AutoModelForCausalLM",
        config: GemmaConfig) -> Weights:
    return load_gemma_weights(parameters_or_model_dir=dict(
        hf_model.named_parameters()),
                              trt_llm_config=config,
                              ckpt_parser=HfParser())
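

# End-to-end sketch (illustrative, not part of the module API): converting an
# HF Gemma checkpoint with the helpers in this file. The model id and the way
# the `GemmaConfig` is built are assumptions about the caller's environment.
#
#   from transformers import AutoModelForCausalLM
#   hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
#   config = GemmaConfig.from_hugging_face(...)  # see GemmaConfig for arguments
#   weights = load_gemma_weights_from_hf_model(hf_model, config)
#   weights = non_modelopt_quantize_if_needed(weights,
#                                             model_dir=Path("google/gemma-2b"),
#                                             quantize_modifiers=QuantizeModifiers(),
#                                             trt_llm_config=config)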