mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-28 22:56:13 +08:00
722 lines
32 KiB
Python
722 lines
32 KiB
Python
import configparser
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import tensorrt_llm
|
|
from tensorrt_llm._utils import (pad_vocab_size, str_dtype_to_np,
|
|
str_dtype_to_torch)
|
|
from tensorrt_llm.functional import is_gated_activation
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.models import GPTLMHeadModel
|
|
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
|
|
from tensorrt_llm.quantization import QuantMode
|
|
|
|
|
|
def get_scaling_factors(
    model_path: Union[str, Path],
    num_layers: int,
    quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[float]]]:
    """ Get the scaling factors for MPT model

    Returns a dictionary of per-layer scaling factors read from a quantized
    MPT checkpoint loadable with ``np.load``.

    Args:
        model_path (str): Path to the quantized MPT model. If None, a warning
            is logged and the result of ``get_dummy_quant_scales(num_layers)``
            is returned instead.
        num_layers (int): Number of layers to read; layers ``0`` to
            ``num_layers - 1`` are looked up in the checkpoint.
        quant_mode (QuantMode): Optional quantization mode. ``'qkv_output'``
            is only populated (with the constant 1.0) when
            ``quant_mode.has_fp8_kv_cache()`` is true.

    Returns:
        dict: Dictionary of scaling factors for the selected layers of the
        mpt model. Every list has ``num_layers`` entries, except
        ``'qkv_output'`` which stays empty when FP8 KV cache is not enabled.

        example:

        {
            'qkv_act': qkv_act_scale,
            'qkv_weights': qkv_weights_scale,
            'qkv_output' : qkv_outputs_scale,
            'dense_act': dense_act_scale,
            'dense_weights': dense_weights_scale,
            'fc_act': fc_act_scale,
            'fc_weights': fc_weights_scale,
            'proj_act': proj_act_scale,
            'proj_weights': proj_weights_scale,
        }
    """

    if model_path is None:
        # The module never imports a bare `logger`; go through the package.
        tensorrt_llm.logger.warning(
            "--quantized_fp8_model_path not specified. "
            "Initialize quantization scales automatically.")
        return get_dummy_quant_scales(num_layers)
    weight_dict = np.load(model_path)

    use_fp8_kv_cache = (quant_mode is not None
                        and quant_mode.has_fp8_kv_cache())

    def scalar(layer, name):
        # Checkpoint keys follow '_np:layers:<layer>:<name>'.
        return weight_dict[f'_np:layers:{layer}:{name}'].item()

    scaling_factor = {
        'qkv_act': [],
        'qkv_weights': [],
        'qkv_output': [],
        'dense_act': [],
        'dense_weights': [],
        'fc_act': [],
        'fc_weights': [],
        'proj_act': [],
        'proj_weights': [],
    }

    for layer in range(num_layers):
        # Q, K and V are fused into one GEMM, so a single scale (the largest
        # of the three) is kept for activations and for weights.
        scaling_factor['qkv_act'].append(
            max(
                scalar(layer,
                       f'attention:qkv:{proj}:activation_scaling_factor')
                for proj in ('q', 'k', 'v')))
        scaling_factor['qkv_weights'].append(
            max(
                scalar(layer, f'attention:qkv:{proj}:weights_scaling_factor')
                for proj in ('q', 'k', 'v')))
        if use_fp8_kv_cache:
            # Not calibrating KV cache.
            scaling_factor['qkv_output'].append(1.0)
        scaling_factor['dense_act'].append(
            scalar(layer, 'attention:dense:activation_scaling_factor'))
        scaling_factor['dense_weights'].append(
            scalar(layer, 'attention:dense:weights_scaling_factor'))
        scaling_factor['fc_act'].append(
            scalar(layer, 'mlp:fc:activation_scaling_factor'))
        scaling_factor['fc_weights'].append(
            scalar(layer, 'mlp:fc:weights_scaling_factor'))
        scaling_factor['proj_act'].append(
            scalar(layer, 'mlp:proj:activation_scaling_factor'))
        scaling_factor['proj_weights'].append(
            scalar(layer, 'mlp:proj:weights_scaling_factor'))

    for k, v in scaling_factor.items():
        if k == 'qkv_output' and not use_fp8_kv_cache:
            # Intentionally left empty above; checking it would always fail.
            continue
        assert len(v) == num_layers, \
            f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'

    return scaling_factor
|
|
|
|
|
|
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
    """Build the file-name suffix used for a rank's weight file.

    The suffix always ends in ``"<rank>.bin"``. With SmoothQuant it is
    prefixed by ``"int8."``, followed by ``"col."`` when per-channel
    quantization is also requested. ``quant_per_channel`` has no effect
    unless ``use_smooth_quant`` is truthy.
    """
    pieces = []
    if use_smooth_quant:
        pieces.append("int8.")
        if quant_per_channel:
            pieces.append("col.")
    pieces.append(f"{rank}.bin")
    return "".join(pieces)
|
|
|
|
|
|
def extract_layer_idx(name):
    """Return the first all-digit component of a dot-separated name.

    The component is returned as a string (not converted to int). ``None``
    is returned when no component consists solely of digits.
    """
    return next((part for part in name.split('.') if part.isdigit()), None)
|
|
|
|
|
|
def split(v, tp_size, idx, dim=0):
    """Return shard ``idx`` of ``v`` when cut into ``tp_size`` equal parts.

    * ``tp_size == 1``: ``v`` itself is returned untouched.
    * 1-D input: split along its only axis (``dim`` is ignored).
    * 2-D input: split along axis ``dim``.
    * Any other rank: ``None``.

    The returned shard is a contiguous copy, not a view of ``v``.
    """
    if tp_size == 1:
        return v
    rank = len(v.shape)
    if rank not in (1, 2):
        return None
    axis = 0 if rank == 1 else dim
    shard = np.split(v, tp_size, axis=axis)[idx]
    return np.ascontiguousarray(shard.copy())
|
|
|
|
|
|
def parse_ft_config(ini_file):
    """Read the ``[gpt]`` section of an INI config file.

    Required options: ``n_embd``, ``n_head``, ``n_layer``, ``n_positions``,
    ``vocab_size`` and ``activation_function``. All other options fall back
    to the defaults visible below; ``intermediate_size`` defaults to
    ``4 * n_embd``. When ``multi_query_mode`` is true, ``n_kv_head`` is
    forced to 1 (and must not be configured to any other value).

    Returns:
        A 15-tuple ``(n_embd, n_head, n_layer, n_positions, vocab_size,
        do_layer_norm_before, hidden_act, rotary_pct, bias, inter_size,
        n_kv_head, dtype, prompt_num_tasks, prompt_max_vocab_size,
        pos_embedding_type)``.
    """
    section = 'gpt'
    cfg = configparser.ConfigParser()
    cfg.read(ini_file)

    # Model dimensions (mandatory).
    n_embd = cfg.getint(section, 'n_embd')
    n_head = cfg.getint(section, 'n_head')
    n_layer = cfg.getint(section, 'n_layer')
    n_positions = cfg.getint(section, 'n_positions')
    vocab_size = cfg.getint(section, 'vocab_size')

    do_layer_norm_before = cfg.getboolean(section,
                                          'do_layer_norm_before',
                                          fallback=True)
    rotary_pct = cfg.getfloat(section, 'rotary_pct', fallback=0.0)
    hidden_act = cfg.get(section, 'activation_function')
    bias = cfg.getboolean(section, 'bias', fallback=True)

    configured_inter_size = cfg.getint(section,
                                       'intermediate_size',
                                       fallback=None)
    dtype = cfg.get(section, 'storage_dtype', fallback='float32')
    inter_size = (4 * n_embd
                  if configured_inter_size is None else configured_inter_size)

    n_kv_head = cfg.getint(section, 'n_kv_head', fallback=None)
    multi_query_mode = cfg.getboolean(section,
                                      'multi_query_mode',
                                      fallback=False)
    assert not multi_query_mode or not n_kv_head or n_kv_head == 1, \
        "if multi_query_mode is enabled, n_kv_head must be 1 or unset"
    if multi_query_mode:
        n_kv_head = 1

    prompt_num_tasks = cfg.getint(section, 'prompt_num_tasks', fallback=0)
    prompt_max_vocab_size = cfg.getint(section,
                                       'prompt_max_vocab_size',
                                       fallback=0)
    pos_embedding_type = cfg.get(section,
                                 'position_embedding_type',
                                 fallback='alibi')

    return (n_embd, n_head, n_layer, n_positions, vocab_size,
            do_layer_norm_before, hidden_act, rotary_pct, bias, inter_size,
            n_kv_head, dtype, prompt_num_tasks, prompt_max_vocab_size,
            pos_embedding_type)
|
|
|
|
|
|
def check_embedding_share(dir_path):
    """Tell whether the embedding table must be shared with the LM head.

    Args:
        dir_path: Directory holding the exported weight files, as a ``str``
            or ``pathlib.Path``.

    Returns:
        bool: ``True`` when ``model.lm_head.weight.bin`` does not exist in
        ``dir_path`` (no dedicated LM-head weight file), ``False`` otherwise.
    """
    # Build the path with pathlib so Path inputs work too; plain string
    # concatenation raised TypeError for them.
    lm_file = Path(dir_path) / 'model.lm_head.weight.bin'
    return not lm_file.exists()
|
|
|
|
|
|
def load_from_ft(tensorrt_llm_gpt: GPTLMHeadModel,
                 dir_path,
                 rank=0,
                 tensor_parallel=1,
                 dtype='float32',
                 use_parallel_embedding=False,
                 sharding_dim=0,
                 share_embedding_table=False):
    """Load raw ``*.bin`` weight files from ``dir_path`` into a GPT model.

    The model hyper-parameters are read from ``<dir_path>/config.ini`` via
    ``parse_ft_config``. Each weight file is read with ``np.fromfile`` and
    assigned to the ``.value`` of the matching parameter of
    ``tensorrt_llm_gpt``. The quantization mode is taken from
    ``tensorrt_llm_gpt.quant_mode`` (``QuantMode(0)`` when the attribute is
    absent) and selects between SmoothQuant, weight-only and plain loading.

    Args:
        tensorrt_llm_gpt: Model whose parameters are populated in place.
        dir_path: Directory with ``config.ini`` and the ``*.bin`` files.
            NOTE(review): the nested ``fromfile`` joins it with ``+ '/' +``,
            so it has to be a ``str`` there — confirm callers never pass a
            ``Path``.
        rank: Rank used in per-rank file names and as the shard index
            passed to ``split``.
        tensor_parallel: Number of tensor-parallel shards.
        dtype: Dtype name converted with ``str_dtype_to_np``; default dtype
            for reading files.
        use_parallel_embedding: When true, the vocabulary embedding is
            sharded with ``split`` along ``sharding_dim``.
        sharding_dim: Axis used to shard the vocabulary embedding.
        share_embedding_table: When true, no LM-head weight is loaded.

    Returns:
        None. The elapsed time is logged at the end.
    """
    tensorrt_llm.logger.info('Loading weights from FT...')
    tik = time.time()

    quant_mode = getattr(tensorrt_llm_gpt, 'quant_mode', QuantMode(0))
    # NOTE(review): plugin_weight_only_quant_type is left unbound when the
    # mode is neither INT8 nor INT4 weight-only; it is only read in the
    # `use_weight_only` branches below.
    if quant_mode.is_int8_weight_only():
        plugin_weight_only_quant_type = torch.int8
    elif quant_mode.is_int4_weight_only():
        plugin_weight_only_quant_type = torch.quint4x2
    n_embd, n_head, n_layer, n_positions, vocab_size, do_layer_norm_before, hidden_act, rotary_pct, bias, inter_size, n_kv_head, *_ = parse_ft_config(
        Path(dir_path) / 'config.ini')
    np_dtype = str_dtype_to_np(dtype)

    def fromfile(dir_path, name, shape=None, dtype=None):
        # Read `<dir_path>/<name>` as a flat array of `dtype` (np_dtype when
        # not given), optionally reshaped. Returns None if the file is missing.
        dtype = np_dtype if dtype is None else dtype
        p = dir_path + '/' + name
        if Path(p).exists():
            t = np.fromfile(p, dtype=dtype)
            if shape is not None:
                t = t.reshape(shape)
            return t
        return None

    def set_smoothquant_scale_factors(module,
                                      pre_scale_weight,
                                      dir_path,
                                      basename,
                                      shape,
                                      per_tok_dyn,
                                      per_channel,
                                      is_qkv=False,
                                      rank=None):
        # Load the SmoothQuant scale files named after `basename` and assign
        # them to `module` (and to `pre_scale_weight` where applicable).
        suffix = "bin"
        if per_channel:
            if rank is not None:
                suffix = f"{rank}." + suffix
            suffix = "col." + suffix

        # Per-tensor scales of non-QKV layers are read as a single [1, 1] value.
        col_shape = shape if (per_channel or is_qkv) else [1, 1]
        if per_tok_dyn:
            # Per-token dynamic scaling: the input scale is set to 1.0.
            if pre_scale_weight is not None:
                pre_scale_weight.value = np.array([1.0], dtype=np.float32)
            if is_qkv and not per_channel:
                t = fromfile(dir_path,
                             f"{basename}scale_w_quant_orig.{rank}.{suffix}",
                             col_shape, np.float32)
            else:
                t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}",
                             col_shape, np.float32)
            module.per_channel_scale.value = t
        else:
            t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1],
                         np.float32)
            # NOTE(review): unlike the branch above there is no None check on
            # pre_scale_weight here — confirm it is always set in this mode.
            pre_scale_weight.value = t
            if is_qkv:
                t = fromfile(dir_path,
                             f"{basename}scale_y_accum_quant.{rank}.{suffix}",
                             col_shape, np.float32)
            else:
                t = fromfile(dir_path,
                             f"{basename}scale_y_accum_quant.{suffix}",
                             col_shape, np.float32)
            module.per_channel_scale.value = t
            t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1],
                         np.float32)
            module.act_scale.value = t

    def set_smoother(module, dir_path, base_name, shape, rank):
        # Load `<base_name>.smoother.<rank>.bin` into module.smoother.
        suffix = f"{rank}.bin"
        t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape,
                     np.float32)
        module.smoother.value = t

    # Determine the quantization mode.
    # (Same lookup as at the top of the function.)
    quant_mode = getattr(tensorrt_llm_gpt, "quant_mode", QuantMode(0))
    # Do we use SmoothQuant?
    use_smooth_quant = quant_mode.has_act_and_weight_quant()
    # Do we use quantization per token?
    quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
    # Do we use quantization per channel?
    quant_per_channel = quant_mode.has_per_channel_scaling()

    # Do we use INT4/INT8 weight-only?
    use_weight_only = quant_mode.is_weight_only()

    # Int8 KV cache
    use_int8_kv_cache = quant_mode.has_int8_kv_cache()

    # Suffix of the per-rank weight files (see gen_suffix).
    suffix = gen_suffix(rank, use_smooth_quant, quant_per_channel)
    # The type of weights.
    w_type = np_dtype if not use_smooth_quant else np.int8

    # Position embedding is optional: skipped when model.wpe.bin is absent.
    pe = fromfile(dir_path, 'model.wpe.bin', [n_positions, n_embd])
    if pe is not None:
        tensorrt_llm_gpt.embedding.position_embedding.weight.value = (pe)

    vocab_embedding_weight = fromfile(dir_path, 'model.wte.bin',
                                      [vocab_size, n_embd])
    if not use_parallel_embedding:
        tensorrt_llm_gpt.embedding.vocab_embedding.weight.value = vocab_embedding_weight
    else:
        if sharding_dim == 0:
            if vocab_size % tensor_parallel != 0:
                # padding
                # Zero rows are appended so the vocabulary axis reaches the
                # padded size before it is split across ranks.
                vocab_size_padded = pad_vocab_size(
                    tensorrt_llm_gpt.embedding.vocab_embedding.num_embeddings,
                    tensor_parallel)
                pad_width = vocab_size_padded - vocab_size
                vocab_embedding_weight = np.pad(vocab_embedding_weight,
                                                ((0, pad_width), (0, 0)),
                                                'constant',
                                                constant_values=0)
        tensorrt_llm_gpt.embedding.vocab_embedding.weight.value = np.ascontiguousarray(
            split(vocab_embedding_weight,
                  tensor_parallel,
                  rank,
                  dim=sharding_dim))

    if do_layer_norm_before:
        tensorrt_llm_gpt.ln_f.bias.value = (fromfile(
            dir_path, 'model.final_layernorm.bias.bin'))
        tensorrt_llm_gpt.ln_f.weight.value = (fromfile(
            dir_path, 'model.final_layernorm.weight.bin'))

    # share input embedding
    if not share_embedding_table:
        lm_head_weight = fromfile(dir_path, 'model.lm_head.weight.bin',
                                  [vocab_size, n_embd])
        if lm_head_weight is None:
            # No dedicated LM-head file: fall back to the token embedding.
            lm_head_weight = fromfile(dir_path, 'model.wte.bin',
                                      [vocab_size, n_embd])
        if vocab_size % tensor_parallel != 0:
            # padding
            vocab_size_padded = tensorrt_llm_gpt.lm_head.out_features * tensor_parallel
            pad_width = vocab_size_padded - vocab_size
            lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)),
                                    'constant',
                                    constant_values=0)
        tensorrt_llm_gpt.lm_head.weight.value = np.ascontiguousarray(
            split(lm_head_weight, tensor_parallel, rank))
    for i in range(n_layer):
        # Width of this rank's fused QKV projection output.
        head_dim = n_embd // n_head
        if n_kv_head == 1:
            # multi-query attention.
            c_attn_out_dim = (n_embd // tensor_parallel) + (head_dim * 2)
        elif n_kv_head:
            # grouped-query attention.
            c_attn_out_dim = (n_embd // tensor_parallel +
                              (head_dim * n_kv_head * 2) // tensor_parallel)
        else:
            # multi-head attention.
            c_attn_out_dim = 3 * n_embd // tensor_parallel
        tensorrt_llm_gpt.layers[i].input_layernorm.weight.value = (fromfile(
            dir_path, 'model.layers.' + str(i) + '.input_layernorm.weight.bin'))
        tensorrt_llm_gpt.layers[i].input_layernorm.bias.value = (fromfile(
            dir_path, 'model.layers.' + str(i) + '.input_layernorm.bias.bin'))
        # Fused QKV weight; files are read as [in, out] and transposed on
        # assignment (except in the weight-only path).
        t = fromfile(
            dir_path, 'model.layers.' + str(i) +
            '.attention.query_key_value.weight.' + suffix,
            [n_embd, c_attn_out_dim], w_type)
        if t is not None:
            dst = tensorrt_llm_gpt.layers[i].attention.qkv.weight
            if use_smooth_quant:
                dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
                set_smoothquant_scale_factors(
                    tensorrt_llm_gpt.layers[i].attention.qkv,
                    tensorrt_llm_gpt.layers[i].input_layernorm.scale_to_int,
                    dir_path,
                    'model.layers.' + str(i) + '.attention.query_key_value.',
                    [1, c_attn_out_dim],
                    quant_per_token_dyn,
                    quant_per_channel,
                    rank=rank,
                    is_qkv=True)
            elif use_weight_only:
                # Quantize on the fly; the op returns the processed weights
                # and their scales.
                processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                    torch.tensor(t), plugin_weight_only_quant_type)
                dst.value = processed_torch_weights.numpy()
                scales = tensorrt_llm_gpt.layers[
                    i].attention.qkv.per_channel_scale
                scales.value = torch_weight_scales.numpy()
            else:
                dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
        if bias:
            t = fromfile(
                dir_path, 'model.layers.' + str(i) +
                '.attention.query_key_value.bias.' + str(rank) + '.bin')
            if t is not None:
                dst = tensorrt_llm_gpt.layers[i].attention.qkv.bias
                dst.value = np.ascontiguousarray(t)

        # Attention output projection.
        dst = tensorrt_llm_gpt.layers[i].attention.dense.weight
        t = fromfile(
            dir_path,
            'model.layers.' + str(i) + '.attention.dense.weight.' + suffix,
            [n_embd // tensor_parallel, n_embd], w_type)
        if use_smooth_quant:
            dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
            dense_scale = getattr(tensorrt_llm_gpt.layers[i].attention,
                                  "quantization_scaling_factor", None)
            set_smoothquant_scale_factors(
                tensorrt_llm_gpt.layers[i].attention.dense, dense_scale,
                dir_path, 'model.layers.' + str(i) + '.attention.dense.',
                [1, n_embd], quant_per_token_dyn, quant_per_channel)
            # change it to the real smoother if dense layer is applied smooth quant
            # tensorrt_llm_gpt.layers[i].attention.dense.smoother.value = np.ones(
            #     [1, n_embd // tensor_parallel], dtype=np.float32)
            set_smoother(tensorrt_llm_gpt.layers[i].attention.dense, dir_path,
                         'model.layers.' + str(i) + '.attention.dense',
                         [1, n_embd // tensor_parallel], rank)
        elif use_weight_only:
            processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type)
            dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_gpt.layers[
                i].attention.dense.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))

        if bias:
            # The file name carries no rank component.
            dst = tensorrt_llm_gpt.layers[i].attention.dense.bias
            dst.value = fromfile(
                dir_path,
                'model.layers.' + str(i) + '.attention.dense.bias.bin')

        dst = tensorrt_llm_gpt.layers[i].post_layernorm.weight
        dst.value = fromfile(
            dir_path,
            'model.layers.' + str(i) + '.post_attention_layernorm.weight.bin')

        dst = tensorrt_llm_gpt.layers[i].post_layernorm.bias
        dst.value = fromfile(
            dir_path,
            'model.layers.' + str(i) + '.post_attention_layernorm.bias.bin')
        # MLP up-projection (hidden -> intermediate).
        t = fromfile(
            dir_path,
            'model.layers.' + str(i) + '.mlp.dense_h_to_4h.weight.' + suffix,
            [n_embd, inter_size // tensor_parallel], w_type)
        if use_smooth_quant:
            tensorrt_llm_gpt.layers[
                i].mlp.fc.weight.value = np.ascontiguousarray(
                    np.transpose(t, [1, 0]))
            set_smoothquant_scale_factors(
                tensorrt_llm_gpt.layers[i].mlp.fc,
                tensorrt_llm_gpt.layers[i].post_layernorm.scale_to_int,
                dir_path,
                'model.layers.' + str(i) + '.mlp.dense_h_to_4h.',
                [1, inter_size // tensor_parallel],
                quant_per_token_dyn,
                quant_per_channel,
                rank=rank)
        elif use_weight_only:
            dst = tensorrt_llm_gpt.layers[i].mlp.fc.weight
            processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type)
            dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_gpt.layers[i].mlp.fc.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            tensorrt_llm_gpt.layers[
                i].mlp.fc.weight.value = np.ascontiguousarray(
                    np.transpose(t, [1, 0]))
        if bias:
            tensorrt_llm_gpt.layers[i].mlp.fc.bias.value = fromfile(
                dir_path, 'model.layers.' + str(i) +
                '.mlp.dense_h_to_4h.bias.' + str(rank) + '.bin')
        if is_gated_activation(hidden_act):
            # Gate weight is always read with the default dtype and the plain
            # '<rank>.bin' suffix, regardless of quantization mode.
            t = fromfile(
                dir_path, 'model.layers.' + str(i) +
                '.mlp.dense_h_to_4h.gate.weight.' + str(rank) + '.bin',
                [n_embd, inter_size // tensor_parallel])
            tensorrt_llm_gpt.layers[
                i].mlp.gate.weight.value = np.ascontiguousarray(
                    np.transpose(t, [1, 0]))

        # MLP down-projection (intermediate -> hidden).
        t = fromfile(
            dir_path,
            'model.layers.' + str(i) + '.mlp.dense_4h_to_h.weight.' + suffix,
            [inter_size // tensor_parallel, n_embd], w_type)
        if use_smooth_quant:
            tensorrt_llm_gpt.layers[
                i].mlp.proj.weight.value = np.ascontiguousarray(
                    np.transpose(t, [1, 0]))
            proj_scale = getattr(tensorrt_llm_gpt.layers[i].mlp,
                                 "quantization_scaling_factor", None)
            set_smoothquant_scale_factors(
                tensorrt_llm_gpt.layers[i].mlp.proj, proj_scale, dir_path,
                'model.layers.' + str(i) + '.mlp.dense_4h_to_h.', [1, n_embd],
                quant_per_token_dyn, quant_per_channel)
            # change it to the real smoother if proj layer is applied smooth quant
            # tensorrt_llm_gpt.layers[i].mlp.proj.smoother.value = np.ones(
            #     [1, inter_size // tensor_parallel], dtype=np.float32)
            set_smoother(tensorrt_llm_gpt.layers[i].mlp.proj, dir_path,
                         'model.layers.' + str(i) + '.mlp.dense_4h_to_h',
                         [1, inter_size // tensor_parallel], rank)
        elif use_weight_only:
            dst = tensorrt_llm_gpt.layers[i].mlp.proj.weight
            processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type)
            dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_gpt.layers[i].mlp.proj.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            tensorrt_llm_gpt.layers[i].mlp.proj.weight.value = (
                np.ascontiguousarray(np.transpose(t, [1, 0])))
        if bias:
            tensorrt_llm_gpt.layers[i].mlp.proj.bias.value = fromfile(
                dir_path,
                'model.layers.' + str(i) + '.mlp.dense_4h_to_h.bias.bin')

        if use_int8_kv_cache:
            # One scale per layer; its reciprocal is stored as the
            # orig->quant scale.
            t = fromfile(
                dir_path, 'model.layers.' + str(i) +
                '.attention.query_key_value.scale_y_quant_orig.bin', [1],
                np.float32)
            tensorrt_llm_gpt.layers[
                i].attention.kv_orig_quant_scale.value = 1.0 / t
            tensorrt_llm_gpt.layers[i].attention.kv_quant_orig_scale.value = t

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
|
|
|
|
|
def load_from_awq_mpt(tensorrt_llm_mpt: GPTLMHeadModel,
                      quant_ckpt_path,
                      mapping=Mapping(),
                      dtype="float16",
                      ft_model_dir=None):
    """Load a groupwise AWQ checkpoint into an MPT model.

    The checkpoint is opened with ``np.load`` and its arrays are looked up
    under keys of the form ``"_np:<name>"``. Linear layers are re-quantized
    to packed INT4 with the ``torch.ops.fastertransformer`` pack and
    preprocess ops, and assigned together with their group scales and
    pre-quantization scales. Layer-norm biases, which the checkpoint does
    not provide, are set to zeros.

    Args:
        tensorrt_llm_mpt: Model whose parameters are populated in place.
        quant_ckpt_path: Path of the AWQ checkpoint. It is used with
            ``.endswith``, so it has to be a ``str``.
        mapping: Tensor/pipeline parallel mapping; its ``tp_size``,
            ``tp_rank``, ``pp_size`` and ``pp_rank`` pick the shard to load.
        dtype: Dtype name converted with ``str_dtype_to_torch``; floating
            point parameters are cast to it before assignment.
        ft_model_dir: Directory with the INT8 KV-cache scale files. Required
            only when the model's quant mode enables INT8 KV cache.

    Raises:
        AssertionError: If a weight cannot be evenly split across
            ``mapping.tp_size``, or if INT8 KV cache is enabled and
            ``ft_model_dir`` or a scale file is missing.
    """
    tensorrt_llm.logger.info(
        'Loading weights from groupwise AWQ mpt checkpoint...')
    tik = time.time()

    awq_mpt = np.load(quant_ckpt_path)
    awq_prefix = "_np:"
    awq_suffix_list = [
        ":weight",
        ":weights_scaling_factor",
        ":prequant_scaling_factor",
    ]
    awq_key_list = [
        "vocab_embedding:weight",  # vocab_embedding lm_head
        "final_layernorm:weight",  # ln_f
        "attention:qkv:",  # attention.qkv
        "attention:dense",  # attention.dense
        "mlp:fc",  # mlp.fc
        "mlp:proj",  # mlp.proj
        "input_layernorm:weight",  # input_layernorm
        "post_layernorm:weight",  # post_layernorm
    ]
    split_sym = ":"
    AMMO_WEIGHT_SCALING_FACTOR_COEFF = 7

    def load(key):
        # Fetch one checkpoint array as a torch tensor; weight scaling
        # factors are multiplied by the AMMO coefficient.
        v = torch.from_numpy(awq_mpt[awq_prefix + key])
        if "weights_scaling_factor" in key:
            v *= AMMO_WEIGHT_SCALING_FACTOR_COEFF  # For AMMO *.npz checkpoints
        return v

    # Group size = number of weights covered by one scaling factor.
    group_size = load("layers:0:attention:dense:weight").numel() // load(
        "layers:0:attention:dense:weights_scaling_factor").numel()

    quant_mode = getattr(tensorrt_llm_mpt, 'quant_mode', QuantMode(0))
    # Int8 KV cache
    use_int8_kv_cache = quant_mode.has_int8_kv_cache()

    packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
    preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
    torch_dtype = str_dtype_to_torch(dtype)

    def fromfile(dir_path, name, shape=None, dtype=None):
        # Read `<dir_path>/<name>` as a flat array, optionally reshaped.
        # Returns None if the file does not exist.
        p = dir_path + '/' + name
        if Path(p).exists():
            t = np.fromfile(p, dtype=dtype)
            if shape is not None:
                t = t.reshape(shape)
            return t
        return None

    def torch_split(v, dim):
        # Return this TP rank's equal slice of `v` along `dim`.
        if v.shape[dim] % mapping.tp_size != 0:
            tensorrt_llm.logger.error(
                "Current weight shape is invalid for mapping.tp_size=" +
                str(mapping.tp_size))
            assert False, "Invalid TP size"
        return v.split(v.shape[dim] // mapping.tp_size,
                       dim=dim)[mapping.tp_rank]

    def AWQ_quantize_pack_preprocess(weight, scale):
        # Divide by the per-group scale (in place), round and clamp to the
        # signed 4-bit range, then pack and preprocess for the mixed GEMM.
        weight /= scale.repeat_interleave(group_size, dim=0)
        qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7)
        int4_weight = preprocessor(packer(qweight_int8.cpu()), torch.quint4x2)
        return int4_weight.view(torch.int8).cpu().numpy()

    def process_and_assign_weight(mOp, v, tp_dim=0):
        # v = [weight, weights_scaling_factor, prequant_scaling_factor].
        weight = v[0].T.contiguous()
        [k, n] = weight.shape
        weight = torch_split(weight, tp_dim)
        amax = v[1].reshape((n, k // group_size)).T.contiguous()
        amax = torch_split(amax, tp_dim)
        pre_quant_scale = v[2].reshape((1, k))
        if tp_dim == 0:
            # Splitting the input dimension also splits the input scale.
            pre_quant_scale = torch_split(pre_quant_scale, 1)
        scale = amax / 8.0
        mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale)
        mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
        mOp.pre_quant_scale.value = pre_quant_scale.to(
            torch_dtype).cpu().numpy()

    def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale):
        # deSmooth and reSmooth
        [k, n] = weight.shape
        if quant_ckpt_path.endswith("pt"):
            # NPZ files are already re-smoothed
            weight *= pre_quant_scale.repeat((n, 1)).transpose(1,
                                                               0).contiguous()
            weight /= avg_pre_quant_scale.repeat(
                (n, 1)).transpose(1, 0).contiguous()

        # Get scale
        # Per-group absolute maximum, laid out as [k // group_size, n].
        weight_t = weight.T.contiguous()
        weight_t = weight_t.reshape(n, k // group_size, group_size)
        weight_t = torch.abs(weight_t.reshape(-1, group_size))
        amax, _ = weight_t.max(1)
        amax = amax.reshape(n, k // group_size).T.contiguous()
        scale = amax / 8
        return weight, scale

    def process_and_assign_qkv_weight(prefix, mOp):
        # Q, K and V are stored separately; they are re-smoothed against the
        # mean of their pre-quant scales and fused along the output axis.
        q_weight = load(prefix + "q" + awq_suffix_list[0]).T.contiguous()
        k_weight = load(prefix + "k" + awq_suffix_list[0]).T.contiguous()
        v_weight = load(prefix + "v" + awq_suffix_list[0]).T.contiguous()
        dim_k = q_weight.shape[0]
        q_weight = torch_split(q_weight, 1)
        k_weight = torch_split(k_weight, 1)
        v_weight = torch_split(v_weight, 1)
        q_pre_quant_scale = load(prefix + "q" + awq_suffix_list[2]).reshape(
            (1, dim_k))
        k_pre_quant_scale = load(prefix + "k" + awq_suffix_list[2]).reshape(
            (1, dim_k))
        v_pre_quant_scale = load(prefix + "v" + awq_suffix_list[2]).reshape(
            (1, dim_k))
        qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale +
                               v_pre_quant_scale) / 3.0
        q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale,
                                                   qkv_pre_quant_scale)
        k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale,
                                                   qkv_pre_quant_scale)
        v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale,
                                                   qkv_pre_quant_scale)
        qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
        qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)

        mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(
            torch_dtype).cpu().numpy()
        mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale)
        mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()

    # Load weights from AWQ checkpoint into TRT-LLM module
    # 1. vocab_embedding and lm_head
    v = load(awq_key_list[0])
    # TRT-LLM requires vocab_size to be multiple of 64 for successful GEMM
    if v.shape[0] % 64 != 0:
        v = torch.nn.functional.pad(v, [0, 0, 0, 64 - v.shape[0] % 64])
    if mapping.is_first_pp_rank():
        tensorrt_llm_mpt.embedding.vocab_embedding.weight.value = v.to(
            torch_dtype).cpu().numpy()
    if mapping.is_last_pp_rank():
        tp_dim = 0
        v = torch_split(v.clone(), tp_dim)
        # Convert like every other parameter here (target dtype, numpy
        # array) instead of assigning the raw torch tensor.
        tensorrt_llm_mpt.lm_head.weight.value = v.to(
            torch_dtype).cpu().numpy()

    # 2. ln_f
    v = load(awq_key_list[1])
    if mapping.is_last_pp_rank():
        tensorrt_llm_mpt.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
        # MPT do not have LN bias, we set 0 here.
        # The freshly assigned weight only provides the shape and dtype.
        random_bias = tensorrt_llm_mpt.ln_f.weight._value
        tensorrt_llm_mpt.ln_f.bias.value = np.zeros(random_bias.shape).astype(
            random_bias.dtype)

    # 3. Weights inside each layer
    num_hidden_layers = tensorrt_llm_mpt._num_layers
    layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
    layers_range = list(
        range(mapping.pp_rank * layers_per_pipeline_stage,
              (mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))

    for l in layers_range:
        layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage
        # NOTE(review): the checkpoint key uses the stage-local index, so
        # pipeline ranks > 0 read layers 0..N-1 of the checkpoint rather
        # than global layer `l` — confirm this is intended.
        prefix = "layers" + split_sym + str(layer_idx) + split_sym
        tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
        layer = tensorrt_llm_mpt.layers[layer_idx]

        # 4.1 attention.qkv
        process_and_assign_qkv_weight(prefix + awq_key_list[2],
                                      layer.attention.qkv)

        # 4.2 attention.dense
        v = [load(prefix + awq_key_list[3] + suf) for suf in awq_suffix_list]
        process_and_assign_weight(layer.attention.dense, v, 0)

        # 4.3 mlp.fc
        v = [load(prefix + awq_key_list[4] + suf) for suf in awq_suffix_list]
        process_and_assign_weight(layer.mlp.fc, v, 1)

        # 4.4 mlp.proj
        v = [load(prefix + awq_key_list[5] + suf) for suf in awq_suffix_list]
        process_and_assign_weight(layer.mlp.proj, v, 0)

        # 4.5 input_layernorm
        v = load(prefix + awq_key_list[6])
        layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy()
        random_bias = layer.input_layernorm.bias._value
        layer.input_layernorm.bias.value = np.zeros(random_bias.shape).astype(
            random_bias.dtype)

        # 4.6 post_layernorm
        v = load(prefix + awq_key_list[7])
        layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy()
        random_bias = layer.post_layernorm.bias._value
        layer.post_layernorm.bias.value = np.zeros(random_bias.shape).astype(
            random_bias.dtype)

        # 4.7 attention.kv_orig_quant_scale / kv_quant_orig_scale
        if use_int8_kv_cache:
            assert ft_model_dir, "You must pass --ft_model_dir to tell TRT-LLM where to look for scales of INT8 kv cache."
            t = fromfile(
                ft_model_dir, 'model.layers.' + str(layer_idx) +
                '.attention.query_key_value.scale_y_quant_orig.bin', [1],
                np.float32)
            assert t is not None, f"{ft_model_dir} does not contain model.layers.{layer_idx}.attention.query_key_value.scale_y_quant_orig.bin"
            layer.attention.kv_orig_quant_scale.value = 1.0 / t
            layer.attention.kv_quant_orig_scale.value = t

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|