mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
455 lines
19 KiB
Python
455 lines
19 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import time
|
|
from operator import attrgetter
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import tensorrt_llm
|
|
import tensorrt_llm.logger as logger
|
|
from tensorrt_llm.models import GPTJForCausalLM
|
|
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
|
|
from tensorrt_llm.quantization import QuantMode
|
|
|
|
|
|
def get_scaling_factors(
|
|
model_path: Union[str, Path],
|
|
num_layers: int,
|
|
quant_mode: Optional[QuantMode] = None,
|
|
) -> Optional[Dict[str, List[int]]]:
|
|
""" Get the scaling factors for GPT-J model
|
|
|
|
Returns a dictionary of scaling factors for the selected layers of the
|
|
GPT-J model.
|
|
|
|
Args:
|
|
model_path (str): Path to the quantized GPT-J model
|
|
layers (list): List of layers to get the scaling factors for. If None,
|
|
all layers are selected.
|
|
|
|
Returns:
|
|
dict: Dictionary of scaling factors for the selected layers of the
|
|
GPT-J model.
|
|
|
|
example:
|
|
|
|
{
|
|
'qkv_act': qkv_act_scale,
|
|
'qkv_weights': qkv_weights_scale,
|
|
'qkv_output' : qkv_outputs_scale,
|
|
'dense_act': dense_act_scale,
|
|
'dense_weights': dense_weights_scale,
|
|
'fc_act': fc_act_scale,
|
|
'fc_weights': fc_weights_scale,
|
|
'proj_act': proj_act_scale,
|
|
'proj_weights': proj_weights_scale,
|
|
}
|
|
"""
|
|
|
|
if model_path is None:
|
|
logger.warning(f"--quantized_fp8_model_path not specified. "
|
|
f"Initialize quantization scales automatically.")
|
|
return get_dummy_quant_scales(num_layers)
|
|
weight_dict = np.load(model_path)
|
|
|
|
# yapf: disable
|
|
scaling_factor = {
|
|
'qkv_act': [],
|
|
'qkv_weights': [],
|
|
'qkv_output': [],
|
|
'dense_act': [],
|
|
'dense_weights': [],
|
|
'fc_act': [],
|
|
'fc_weights': [],
|
|
'proj_act': [],
|
|
'proj_weights': [],
|
|
}
|
|
|
|
for layer in range(num_layers):
|
|
scaling_factor['qkv_act'].append(max(
|
|
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
|
|
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
|
|
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
|
|
))
|
|
scaling_factor['qkv_weights'].append(max(
|
|
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
|
|
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
|
|
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
|
|
))
|
|
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
|
|
# Not calibrarting KV cache.
|
|
scaling_factor['qkv_output'].append(1.0)
|
|
scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
|
|
scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
|
|
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
|
|
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
|
|
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
|
|
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
|
|
# yapf: enable
|
|
for k, v in scaling_factor.items():
|
|
assert len(v) == num_layers, \
|
|
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
|
|
|
|
return scaling_factor
|
|
|
|
|
|
def load_from_hf_gpt_j(tensorrt_llm_gpt_j: GPTJForCausalLM,
|
|
hf_gpt_j,
|
|
fp16=False,
|
|
scaling_factors=None):
|
|
|
|
hf_model_gptj_block_names = [
|
|
"ln_1.weight",
|
|
"ln_1.bias",
|
|
"mlp.fc_in.weight",
|
|
"mlp.fc_in.bias",
|
|
"mlp.fc_out.weight",
|
|
"mlp.fc_out.bias",
|
|
]
|
|
|
|
tensorrt_llm_model_gptj_block_names = [
|
|
"input_layernorm.weight",
|
|
"input_layernorm.bias",
|
|
"mlp.fc.weight",
|
|
"mlp.fc.bias",
|
|
"mlp.proj.weight",
|
|
"mlp.proj.bias",
|
|
]
|
|
|
|
quant_mode = getattr(tensorrt_llm_gpt_j, 'quant_mode', QuantMode(0))
|
|
|
|
tensorrt_llm.logger.info('Loading weights from HF GPT-J...')
|
|
tik = time.time()
|
|
|
|
torch_dtype = torch.float16 if fp16 else torch.float32
|
|
hf_gpt_j_state_dict = hf_gpt_j.state_dict()
|
|
|
|
v = hf_gpt_j_state_dict.get('transformer.wte.weight')
|
|
tensorrt_llm_gpt_j.embedding.weight.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
n_layer = hf_gpt_j.config.n_layer
|
|
|
|
for layer_idx in range(n_layer):
|
|
prefix = "transformer.h." + str(layer_idx) + "."
|
|
for idx, hf_attr in enumerate(hf_model_gptj_block_names):
|
|
v = hf_gpt_j_state_dict.get(prefix + hf_attr)
|
|
layer = attrgetter(tensorrt_llm_model_gptj_block_names[idx])(
|
|
tensorrt_llm_gpt_j.layers[layer_idx])
|
|
if idx == 2 and scaling_factors:
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].mlp.fc.activation_scaling_factor.value = np.array(
|
|
[scaling_factors['fc_act'][layer_idx]],
|
|
dtype=np.float32)
|
|
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].mlp.fc.weights_scaling_factor.value = np.array(
|
|
[scaling_factors['fc_weights'][layer_idx]],
|
|
dtype=np.float32)
|
|
|
|
elif idx == 4 and scaling_factors:
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].mlp.proj.activation_scaling_factor.value = np.array(
|
|
[scaling_factors['proj_act'][layer_idx]],
|
|
dtype=np.float32)
|
|
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].mlp.proj.weights_scaling_factor.value = np.array(
|
|
[scaling_factors['proj_weights'][layer_idx]],
|
|
dtype=np.float32)
|
|
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
|
|
|
|
# Attention QKV Linear
|
|
# concatenate the Q, K, V layers weights.
|
|
q_weights = hf_gpt_j_state_dict.get(prefix + "attn.q_proj.weight")
|
|
k_weights = hf_gpt_j_state_dict.get(prefix + "attn.k_proj.weight")
|
|
v_weights = hf_gpt_j_state_dict.get(prefix + "attn.v_proj.weight")
|
|
qkv_weights = torch.cat((q_weights, k_weights, v_weights))
|
|
layer = attrgetter("attention.qkv.weight")(
|
|
tensorrt_llm_gpt_j.layers[layer_idx])
|
|
setattr(layer, "value", qkv_weights.to(torch_dtype).cpu().numpy())
|
|
if scaling_factors:
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].attention.qkv.activation_scaling_factor.value = np.array(
|
|
[scaling_factors['qkv_act'][layer_idx]], dtype=np.float32)
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].attention.qkv.weights_scaling_factor.value = np.array(
|
|
[scaling_factors['qkv_weights'][layer_idx]],
|
|
dtype=np.float32)
|
|
|
|
if quant_mode.has_fp8_kv_cache():
|
|
if scaling_factors:
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].attention.kv_orig_quant_scale.value = np.array(
|
|
[scaling_factors['qkv_output'][layer_idx]],
|
|
dtype=np.float32)
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].attention.kv_quant_orig_scale.value = np.array(
|
|
[1.0 / scaling_factors['qkv_output'][layer_idx]],
|
|
dtype=np.float32)
|
|
|
|
# Attention Dense (out_proj) Linear
|
|
v = hf_gpt_j_state_dict.get(prefix + "attn.out_proj.weight")
|
|
layer = attrgetter("attention.dense.weight")(
|
|
tensorrt_llm_gpt_j.layers[layer_idx])
|
|
setattr(layer, "value", v.to(torch_dtype).cpu().numpy())
|
|
if scaling_factors:
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].attention.dense.activation_scaling_factor.value = np.array(
|
|
[scaling_factors['dense_act'][layer_idx]], dtype=np.float32)
|
|
tensorrt_llm_gpt_j.layers[
|
|
layer_idx].attention.dense.weights_scaling_factor.value = np.array(
|
|
[scaling_factors['dense_weights'][layer_idx]],
|
|
dtype=np.float32)
|
|
|
|
v = hf_gpt_j_state_dict.get('transformer.ln_f.weight')
|
|
tensorrt_llm_gpt_j.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
v = hf_gpt_j_state_dict.get('transformer.ln_f.bias')
|
|
tensorrt_llm_gpt_j.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
v = hf_gpt_j_state_dict.get('lm_head.weight')
|
|
tensorrt_llm_gpt_j.lm_head.weight.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
v = hf_gpt_j_state_dict.get('lm_head.bias')
|
|
tensorrt_llm_gpt_j.lm_head.bias.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
tok = time.time()
|
|
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
|
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
|
|
|
|
|
def AWQ_quantize_pack_preprocess(weight, scale, group_size, packer,
|
|
preprocessor):
|
|
scale = scale.repeat_interleave(group_size, dim=0)
|
|
weight = weight / scale
|
|
weight = torch.round(weight).char()
|
|
weight = torch.where(weight > 7, 7, weight)
|
|
qweight_int8 = torch.where(weight < -8, -8, weight)
|
|
int4_weight = packer(qweight_int8.cpu())
|
|
int4_weight = preprocessor(int4_weight, torch.quint4x2)
|
|
return int4_weight.view(torch.float32).cpu().numpy()
|
|
|
|
|
|
def process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
|
preprocessor, torch_dtype):
|
|
weight = awq_gpt_j[mPrefix + ".weight"].T.contiguous()
|
|
[k, n] = weight.shape
|
|
amax = awq_gpt_j[mPrefix + ".weight_quantizer._amax"].reshape(
|
|
(n, int(k / group_size))).T.contiguous()
|
|
pre_quant_scale = awq_gpt_j[mPrefix +
|
|
".input_quantizer._pre_quant_scale"].reshape(
|
|
(1, k))
|
|
scale = amax / 8.0
|
|
mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale, group_size,
|
|
packer, preprocessor)
|
|
mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
|
|
mOp.pre_quant_scale.value = pre_quant_scale.to(torch_dtype).cpu().numpy()
|
|
|
|
|
|
def deSmooth(weight, pre_quant_scale):
|
|
[k, n] = weight.shape
|
|
pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1,
|
|
0).contiguous()
|
|
weight = weight * pre_quant_scale
|
|
return weight
|
|
|
|
|
|
def reSmooth(weight, pre_quant_scale):
|
|
[k, n] = weight.shape
|
|
pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1,
|
|
0).contiguous()
|
|
weight = weight / pre_quant_scale
|
|
return weight
|
|
|
|
|
|
def get_scale(weight, group_size):
|
|
weight = weight.T.contiguous()
|
|
[n, k] = weight.shape
|
|
weight = weight.reshape(n, int(k / group_size), group_size)
|
|
weight = torch.abs(weight.reshape(-1, group_size))
|
|
amax, idx = weight.max(1)
|
|
amax = amax.reshape(n, int(k / group_size)).T.contiguous()
|
|
return amax / 8
|
|
|
|
|
|
def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale,
|
|
group_size):
|
|
weight = deSmooth(weight, pre_quant_scale)
|
|
weight = reSmooth(weight, avg_pre_quant_scale)
|
|
scale = get_scale(weight, group_size)
|
|
return weight, scale
|
|
|
|
|
|
def process_and_assign_qkv_weight(awq_gpt_j, prefix, mOp, group_size, packer,
|
|
preprocessor, torch_dtype):
|
|
q_weight = awq_gpt_j[prefix + "attn.q_proj.weight"].T.contiguous()
|
|
k_weight = awq_gpt_j[prefix + "attn.k_proj.weight"].T.contiguous()
|
|
v_weight = awq_gpt_j[prefix + "attn.v_proj.weight"].T.contiguous()
|
|
[k, n] = q_weight.shape
|
|
|
|
q_pre_quant_scale = awq_gpt_j[
|
|
prefix + "attn.q_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
|
|
k_pre_quant_scale = awq_gpt_j[
|
|
prefix + "attn.k_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
|
|
v_pre_quant_scale = awq_gpt_j[
|
|
prefix + "attn.v_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
|
|
|
|
qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale +
|
|
v_pre_quant_scale) / 3.0
|
|
q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale,
|
|
qkv_pre_quant_scale, group_size)
|
|
k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale,
|
|
qkv_pre_quant_scale, group_size)
|
|
v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale,
|
|
qkv_pre_quant_scale, group_size)
|
|
|
|
qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
|
|
qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)
|
|
mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(
|
|
torch_dtype).cpu().numpy()
|
|
mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale,
|
|
group_size, packer,
|
|
preprocessor)
|
|
mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()
|
|
|
|
|
|
def load_from_awq_gpt_j(tensorrt_llm_gpt_j: GPTJForCausalLM,
|
|
awq_gpt_j,
|
|
config,
|
|
fp16=False,
|
|
group_size=128):
|
|
|
|
awq_gptj_block_names = [
|
|
"ln_1.weight",
|
|
"ln_1.bias",
|
|
"mlp.fc_in.bias",
|
|
"mlp.fc_out.bias",
|
|
]
|
|
|
|
tensorrt_llm_model_gptj_block_names = [
|
|
"input_layernorm.weight",
|
|
"input_layernorm.bias",
|
|
"mlp.fc.bias",
|
|
"mlp.proj.bias",
|
|
]
|
|
|
|
getattr(tensorrt_llm_gpt_j, 'quant_mode', QuantMode(0))
|
|
|
|
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
|
|
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
|
|
|
|
tensorrt_llm.logger.info('Loading weights from AWQ GPT-J...')
|
|
tik = time.time()
|
|
|
|
torch_dtype = torch.float16 if fp16 else torch.float32
|
|
|
|
#check if we need to pad vocab
|
|
v = awq_gpt_j.get('transformer.wte.weight')
|
|
[vocab_size, k] = v.shape
|
|
pad_vocab = False
|
|
pad_vocab_size = vocab_size
|
|
if vocab_size % 64 != 0:
|
|
pad_vocab = True
|
|
pad_vocab_size = int((vocab_size + 63) / 64) * 64
|
|
if pad_vocab:
|
|
new_v = torch.zeros([pad_vocab_size, k])
|
|
new_v[:vocab_size, :] = v
|
|
v = new_v
|
|
tensorrt_llm_gpt_j.embedding.weight.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
n_layer = config["n_layer"]
|
|
|
|
for layer_idx in range(n_layer):
|
|
prefix = "transformer.h." + str(layer_idx) + "."
|
|
tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
|
|
for idx, awq_attr in enumerate(awq_gptj_block_names):
|
|
v = awq_gpt_j[prefix + awq_attr]
|
|
layer = attrgetter(tensorrt_llm_model_gptj_block_names[idx])(
|
|
tensorrt_llm_gpt_j.layers[layer_idx])
|
|
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
|
|
|
|
# Attention QKV Linear
|
|
# concatenate the Q, K, V layers weights.
|
|
process_and_assign_qkv_weight(
|
|
awq_gpt_j, prefix,
|
|
tensorrt_llm_gpt_j.layers[layer_idx].attention.qkv, group_size,
|
|
packer, preprocessor, torch_dtype)
|
|
|
|
# Attention Dense (out_proj) Linear
|
|
mPrefix = prefix + "attn.out_proj"
|
|
mOp = tensorrt_llm_gpt_j.layers[layer_idx].attention.dense
|
|
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
|
preprocessor, torch_dtype)
|
|
|
|
# MLP Dense (mlp.fc) Linear
|
|
mPrefix = prefix + "mlp.fc_in"
|
|
mOp = tensorrt_llm_gpt_j.layers[layer_idx].mlp.fc
|
|
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
|
preprocessor, torch_dtype)
|
|
|
|
# MLP Desne (mlp.proj) Linear
|
|
mPrefix = prefix + "mlp.fc_out"
|
|
mOp = tensorrt_llm_gpt_j.layers[layer_idx].mlp.proj
|
|
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
|
preprocessor, torch_dtype)
|
|
|
|
v = awq_gpt_j['transformer.ln_f.weight']
|
|
tensorrt_llm_gpt_j.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
v = awq_gpt_j['transformer.ln_f.bias']
|
|
tensorrt_llm_gpt_j.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
#lm_head
|
|
if pad_vocab:
|
|
weight = awq_gpt_j['lm_head.weight']
|
|
[vocab_size, k] = weight.shape
|
|
new_weight = torch.zeros([pad_vocab_size, k])
|
|
new_weight[:vocab_size, :] = weight
|
|
new_weight = new_weight.T.contiguous()
|
|
amax = awq_gpt_j['lm_head.weight_quantizer._amax'].reshape(
|
|
[vocab_size, int(k / group_size)])
|
|
new_amax = torch.ones([pad_vocab_size, int(k / group_size)])
|
|
new_amax[:vocab_size, :] = amax
|
|
new_amax = new_amax.T.contiguous()
|
|
new_scale = new_amax / 8
|
|
tensorrt_llm_gpt_j.lm_head.qweight.value = AWQ_quantize_pack_preprocess(
|
|
new_weight, new_scale, group_size, packer, preprocessor)
|
|
tensorrt_llm_gpt_j.lm_head.scale.value = new_scale.to(
|
|
torch_dtype).cpu().numpy()
|
|
tensorrt_llm_gpt_j.lm_head.pre_quant_scale.value = awq_gpt_j[
|
|
'lm_head.input_quantizer._pre_quant_scale'].to(
|
|
torch_dtype).cpu().numpy()
|
|
|
|
bias = awq_gpt_j['lm_head.bias']
|
|
new_bias = torch.zeros([pad_vocab_size])
|
|
new_bias[:vocab_size] = bias
|
|
tensorrt_llm_gpt_j.lm_head.bias.value = new_bias.to(
|
|
torch_dtype).cpu().numpy()
|
|
else:
|
|
mPrefix = "lm_head"
|
|
mOp = tensorrt_llm_gpt_j.lm_head
|
|
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
|
preprocessor, torch_dtype)
|
|
|
|
v = awq_gpt_j['lm_head.bias']
|
|
tensorrt_llm_gpt_j.lm_head.bias.value = v.to(torch_dtype).cpu().numpy()
|
|
|
|
tok = time.time()
|
|
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
|
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|