TensorRT-LLM/examples/internlm/weight.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import configparser
import math
import time
from operator import attrgetter
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from safetensors import safe_open
import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import InternLMForCausalLM
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
from tensorrt_llm.quantization import QuantMode
def get_scaling_factors(
model_path: Union[str, Path],
num_layers: int,
quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[float]]]:
    """Get the scaling factors for the InternLM model.

    Returns a dictionary of per-layer scaling factors read from the quantized
    checkpoint.

    Args:
        model_path (str): Path to the .npz file holding the quantization
            scales. If None, dummy scales are generated instead.
        num_layers (int): Number of transformer layers to read scales for.
        quant_mode (QuantMode): Quantization mode; when it enables an FP8 KV
            cache, the 'qkv_output' scales are filled with 1.0.

    Returns:
        dict: Dictionary of scaling factors, one list of length num_layers per
        entry, for example:

        {
            'qkv_act': qkv_act_scale,
            'qkv_weights': qkv_weights_scale,
            'qkv_output': qkv_outputs_scale,
            'dense_act': dense_act_scale,
            'dense_weights': dense_weights_scale,
            'fc_act': fc_act_scale,
            'fc_weights': fc_weights_scale,
            'gate_act': gate_act_scale,
            'gate_weights': gate_weights_scale,
            'proj_act': proj_act_scale,
            'proj_weights': proj_weights_scale,
        }
    """
if model_path is None:
logger.warning(f"--quantized_fp8_model_path not specified. "
f"Initialize quantization scales automatically.")
return get_dummy_quant_scales(num_layers)
weight_dict = np.load(model_path)
# yapf: disable
scaling_factor = {
'qkv_act': [],
'qkv_weights': [],
'qkv_output': [],
'dense_act': [],
'dense_weights': [],
'fc_act': [],
'fc_weights': [],
'gate_act': [],
'gate_weights': [],
'proj_act': [],
'proj_weights': [],
}
for layer in range(num_layers):
scaling_factor['qkv_act'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
))
scaling_factor['qkv_weights'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
))
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
            # Not calibrating KV cache.
scaling_factor['qkv_output'].append(1.0)
scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item())
scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item())
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
# yapf: enable
for k, v in scaling_factor.items():
assert len(v) == num_layers, \
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
return scaling_factor
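

# Illustrative usage (a sketch, not taken from the example scripts): the .npz
# file must contain the '_np:layers:{i}:...' keys read above, and quant_mode
# should report has_fp8_kv_cache() so that 'qkv_output' is filled for every
# layer.
#
#   scales = get_scaling_factors("fp8_scales.npz", num_layers=32,
#                                quant_mode=quant_mode)
#   assert all(len(v) == 32 for v in scales.values())

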
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
suffix = f"{rank}.bin"
if use_smooth_quant:
sq_prefix = "int8."
if quant_per_channel:
sq_prefix += "col."
suffix = sq_prefix + suffix
return suffix
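

# For example, gen_suffix(0, False, False) -> "0.bin",
# gen_suffix(1, True, False) -> "int8.1.bin", and
# gen_suffix(1, True, True) -> "int8.col.1.bin".

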
def extract_layer_idx(name):
ss = name.split('.')
for s in ss:
if s.isdigit():
return s
return None
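

# For example, extract_layer_idx('model.layers.15.mlp.up_proj.weight') returns
# the string '15'; names without a numeric component return None.

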
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
if len(v.shape) == 1:
return np.ascontiguousarray(np.split(v, tp_size)[idx].copy())
else:
return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx].copy())
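

# For example, split(np.arange(8), tp_size=2, idx=1) -> array([4, 5, 6, 7]);
# for a 2-D weight, split(w, tp_size, rank, dim=1) keeps this rank's
# contiguous slice of the columns.

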
def dup_kv_weight(v, num_head, tp_size):
assert tp_size % num_head == 0
reps = tp_size // num_head
head_size = v.shape[0] // num_head
v = v.reshape(num_head, head_size,
-1)[:, None, :, :].expand(num_head, reps, head_size,
v.shape[1])
return v.reshape(num_head * reps * head_size, -1).clone()
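

# For example, with 2 KV heads of head_size 128 and tp_size=4, a [256, 4096]
# wk/wv tensor is expanded to [512, 4096] so that every TP rank still owns a
# full KV head after the split.

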
def parse_ft_config(ini_file):
gpt_config = configparser.ConfigParser()
gpt_config.read(ini_file)
n_embd = gpt_config.getint('internlm', 'hidden_size')
n_head = gpt_config.getint('internlm', 'num_attention_heads')
n_layer = gpt_config.getint('internlm', 'num_hidden_layers')
n_positions = gpt_config.getint('internlm', 'max_position_embeddings')
vocab_size = gpt_config.getint('internlm', 'vocab_size')
hidden_act = gpt_config.get('internlm', 'hidden_act')
inter_size = gpt_config.getint('internlm',
'intermediate_size',
fallback=None)
n_kv_head = gpt_config.getint('internlm',
'num_key_value_heads',
fallback=n_head)
attn_bias = gpt_config.getboolean('internlm', 'bias', fallback=False)
if inter_size is None:
inter_size = math.ceil(8 / 3 * n_embd / 256) * 256
return n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head, attn_bias
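

# parse_ft_config expects an FT-style config.ini with an [internlm] section,
# for example (values shown are illustrative only):
#
#   [internlm]
#   hidden_size = 4096
#   num_attention_heads = 32
#   num_hidden_layers = 32
#   max_position_embeddings = 2048
#   vocab_size = 103168
#   hidden_act = silu
#   intermediate_size = 11008
#   num_key_value_heads = 32
#   bias = True

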
def load_from_hf_internlm(
tensorrt_llm_internlm: tensorrt_llm.models.InternLMForCausalLM,
hf_internlm,
mapping=Mapping(),
dtype='float32'):
tensorrt_llm.logger.info('Loading weights from HF InternLM...')
tik = time.time()
quant_mode = getattr(tensorrt_llm_internlm, 'quant_mode', QuantMode(0))
if quant_mode.is_int8_weight_only():
plugin_weight_only_quant_type = torch.int8
elif quant_mode.is_int4_weight_only():
plugin_weight_only_quant_type = torch.quint4x2
use_weight_only = quant_mode.is_weight_only()
num_kv_heads = tensorrt_llm_internlm.num_kv_heads
mha_mode = (num_kv_heads == tensorrt_llm_internlm.num_heads)
assert mha_mode, "All InternLM variants should be MHA mode"
model_params = dict(hf_internlm.named_parameters())
for l in range(hf_internlm.config.num_hidden_layers):
prefix = f'model.layers.{l}.self_attn.'
q_weight = model_params[prefix + 'q_proj.weight']
k_weight = model_params[prefix + 'k_proj.weight']
v_weight = model_params[prefix + 'v_proj.weight']
if not mha_mode:
head_size = tensorrt_llm_internlm.hidden_size // tensorrt_llm_internlm.num_heads
if num_kv_heads < mapping.tp_size:
# duplicate the KV heads up to tensor_parallel
k_weight = dup_kv_weight(k_weight, num_kv_heads,
mapping.tp_size)
v_weight = dup_kv_weight(v_weight, num_kv_heads,
mapping.tp_size)
assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
qkv_weight = [q_weight, k_weight, v_weight]
else:
qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
model_params[prefix + 'qkv_proj.weight'] = qkv_weight
if prefix + 'q_proj.bias' in model_params:
# only used in 7B models
# assert not mha_mode, "MHA mode not used in internlm 7B models"
q_bias = model_params[prefix + 'q_proj.bias']
k_bias = model_params[prefix + 'k_proj.bias']
v_bias = model_params[prefix + 'v_proj.bias']
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
model_params[prefix + 'qkv_proj.bias'] = qkv_bias
torch_dtype = str_dtype_to_torch(dtype)
layers_per_pipeline_stage = hf_internlm.config.num_hidden_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
for k, v in model_params.items():
if isinstance(v, list):
v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v]
else:
v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
if 'model.embed_tokens.weight' in k:
if tensorrt_llm_internlm.use_parallel_embedding:
v = split(v, mapping.tp_size, mapping.tp_rank,
tensorrt_llm_internlm.embedding_sharding_dim)
if mapping.is_first_pp_rank():
tensorrt_llm_internlm.vocab_embedding.weight.value = v
elif 'model.norm.weight' in k:
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.ln_f.weight.value = v
elif 'lm_head.weight' in k:
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.lm_head.weight.value = np.ascontiguousarray(
split(v, mapping.tp_size, mapping.tp_rank))
else:
layer_idx = extract_layer_idx(k)
if layer_idx is None or int(layer_idx) not in layers_range:
continue
idx = int(layer_idx) - mapping.pp_rank * layers_per_pipeline_stage
if idx >= tensorrt_llm_internlm.num_layers:
continue
if 'input_layernorm.weight' in k:
tensorrt_llm_internlm.layers[
idx].input_layernorm.weight.value = v
elif 'post_attention_layernorm.weight' in k:
dst = tensorrt_llm_internlm.layers[idx].post_layernorm.weight
dst.value = v
elif 'self_attn.qkv_proj.weight' in k:
dst = tensorrt_llm_internlm.layers[idx].attention.qkv.weight
if not mha_mode:
assert isinstance(v, list) and len(v) == 3
wq = split(v[0], mapping.tp_size, mapping.tp_rank)
wk = split(v[1], mapping.tp_size, mapping.tp_rank)
wv = split(v[2], mapping.tp_size, mapping.tp_rank)
split_v = np.concatenate((wq, wk, wv))
else:
q_emb = v.shape[0] // 3
model_emb = v.shape[1]
v = v.reshape(3, q_emb, model_emb)
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
split_v = split_v.reshape(3 * (q_emb // mapping.tp_size),
model_emb)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_internlm.layers[
idx].attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'self_attn.qkv_proj.bias' in k:
dst = tensorrt_llm_internlm.layers[idx].attention.qkv.bias
if not mha_mode:
assert isinstance(v, list) and len(v) == 3
bq = split(v[0], mapping.tp_size, mapping.tp_rank)
bk = split(v[1], mapping.tp_size, mapping.tp_rank)
bv = split(v[2], mapping.tp_size, mapping.tp_rank)
split_v = np.concatenate((bq, bk, bv))
else:
q_emb = v.shape[0] // 3
v = v.reshape(3, q_emb)
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
split_v = split_v.reshape(3 * (q_emb // mapping.tp_size))
dst.value = np.ascontiguousarray(split_v)
elif 'self_attn.o_proj.weight' in k:
dst = tensorrt_llm_internlm.layers[idx].attention.dense.weight
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_internlm.layers[
idx].attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'self_attn.o_proj.bias' in k:
dst = tensorrt_llm_internlm.layers[idx].attention.dense.bias
                split_v = v  # dense bias is added after the TP all-reduce, so it is not split across ranks
dst.value = np.ascontiguousarray(split_v)
elif 'mlp.up_proj.weight' in k:
dst = tensorrt_llm_internlm.layers[idx].mlp.gate.weight
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_internlm.layers[
idx].mlp.gate.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'mlp.down_proj.weight' in k:
dst = tensorrt_llm_internlm.layers[idx].mlp.proj.weight
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_internlm.layers[
idx].mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'mlp.gate_proj.weight' in k:
dst = tensorrt_llm_internlm.layers[idx].mlp.fc.weight
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_internlm.layers[
idx].mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
return
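

# Illustrative call site (a sketch; in the example build script the HF model
# is typically created with transformers.AutoModelForCausalLM.from_pretrained(
# ..., trust_remote_code=True) before being handed to this loader):
#
#   load_from_hf_internlm(trtllm_internlm, hf_model,
#                         mapping=Mapping(world_size=2, rank=0, tp_size=2),
#                         dtype='float16')

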
def load_from_meta_internlm(
tensorrt_llm_internlm: tensorrt_llm.models.InternLMForCausalLM,
meta_ckpt_dir,
mapping=Mapping(),
dtype="float32"):
torch_dtype = str_dtype_to_torch(dtype)
def gather_ckpts(ckpts):
gathered = {}
for k in ckpts[0]:
d = 0
if any([n in k for n in ["wo", "w2", "tok"]]):
d = 1
if "norm" in k or "rope" in k: # no TP
gathered[k] = ckpts[0][k].clone()
else:
gathered[k] = torch.cat([pt[k] for pt in ckpts], dim=d).clone()
return gathered
def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank):
split_ckpt = {}
for k in ckpt:
d = 0
if any([n in k for n in ["wo", "w2", "tok"]]):
d = 1
if "norm" in k or "rope" in k: # no TP
split_ckpt[k] = ckpt[k].clone()
elif tensorrt_llm_internlm.num_kv_heads < mapping.tp_size and any(
[n in k for n in ["wk", "wv"]]):
assert mapping.tp_size % tensorrt_llm_internlm.num_kv_heads == 0
# special case: we need to duplicate KV head
tmp = dup_kv_weight(ckpt[k], tensorrt_llm_internlm.num_kv_heads,
mapping.tp_size)
split_ckpt[k] = torch.split(tmp,
tmp.shape[d] // ranks_per_ckpt,
dim=d)[ckpt_rank].clone()
else:
split_ckpt[k] = torch.split(ckpt[k],
ckpt[k].shape[d] // ranks_per_ckpt,
dim=d)[ckpt_rank].clone()
return split_ckpt
def get_current_weights(num_ckpts):
if num_ckpts > mapping.tp_size:
# combine ckpts
assert (num_ckpts % mapping.tp_size) == 0
nf = num_ckpts // mapping.tp_size
fs = nf * mapping.tp_rank
file_ids = list(range(fs, fs + nf))
ckpts = []
for f in file_ids:
ckpt = torch.load(Path(meta_ckpt_dir,
f"consolidated.{f:02d}.pth"),
map_location="cpu")
ckpts.append(ckpt)
return gather_ckpts(ckpts)
elif num_ckpts < mapping.tp_size:
# split ckpt
assert (mapping.tp_size % num_ckpts) == 0
ranks_per_ckpt = mapping.tp_size // num_ckpts
ckpt_fid = mapping.tp_rank // ranks_per_ckpt
ckpt_rank = mapping.tp_rank % ranks_per_ckpt
nH_per_ckpt = tensorrt_llm_internlm.num_heads // num_ckpts
assert (nH_per_ckpt % ranks_per_ckpt) == 0
ckpt = torch.load(Path(meta_ckpt_dir,
f"consolidated.{ckpt_fid:02d}.pth"),
map_location="cpu")
return split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank)
# num_ckpts == tensor_parallel, 1:1 mapping from files to TP
return torch.load(Path(meta_ckpt_dir,
f"consolidated.{mapping.tp_rank:02d}.pth"),
map_location="cpu")
def permute(w, nH, d, dH):
# due to MQA's wk, nH*dH != d could be true
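        # Reorder each head's rows from the pairwise-interleaved rotary layout
        # of the original checkpoints to the half-split (rotate-half) layout
        # expected by the attention used downstream.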
return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d)
if not hasattr(load_from_meta_internlm, "saved_embed"):
load_from_meta_internlm.saved_embed = None
def gather_embedding(cur_embed, name: str, num_ckpts):
if mapping.tp_size == 1:
# even if num_ckpts > 1, get_current_weights will already have it gathered
return cur_embed
if load_from_meta_internlm.saved_embed is None:
embeds = [None] * num_ckpts
for i in range(num_ckpts):
ckpt = torch.load(Path(meta_ckpt_dir,
f"consolidated.{i:02d}.pth"),
map_location="cpu")
embeds[i] = ckpt[name]
embed = torch.cat(embeds, dim=1).to(torch_dtype)
load_from_meta_internlm.saved_embed = torch_to_numpy(
embed) # cache the embedding, not needed if no refit
return load_from_meta_internlm.saved_embed
tensorrt_llm.logger.info(
'Loading weights from Meta InternLM checkpoints ...')
tik = time.time()
    quant_mode = getattr(tensorrt_llm_internlm, 'quant_mode', QuantMode(0))
    # Weight-only quantization is not applied by this loader; quant_mode is
    # kept only for parity with the other loaders.
num_kv_heads = tensorrt_llm_internlm.num_kv_heads
mha_mode = (num_kv_heads == tensorrt_llm_internlm.num_heads)
ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth"))
num_ckpts = len(ckpts)
# internlm/internlm2 doesn't have MQA. So, simplifying loader logic by not worrying about it.
assert num_kv_heads > 1 or num_kv_heads >= num_ckpts, \
f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints."
head_size = tensorrt_llm_internlm.hidden_size // tensorrt_llm_internlm.num_heads
ckpt = get_current_weights(num_ckpts)
layers_range = list(
range(mapping.pp_rank * tensorrt_llm_internlm.num_layers,
(mapping.pp_rank + 1) * tensorrt_llm_internlm.num_layers, 1))
for l in layers_range:
prefix = f'layers.{l}.attention.'
q_weight = permute(ckpt[prefix + 'wq.weight'].clone(),
nH=(tensorrt_llm_internlm.num_heads //
mapping.tp_size),
d=tensorrt_llm_internlm.hidden_size,
dH=head_size)
if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size:
assert mapping.tp_size % num_kv_heads == 0
assert False, "Not supported yet"
k_weight = permute(ckpt[prefix + 'wk.weight'].clone(),
nH=((num_kv_heads + mapping.tp_size - 1) //
mapping.tp_size),
d=tensorrt_llm_internlm.hidden_size,
dH=head_size)
v_weight = ckpt[prefix + 'wv.weight'].clone()
qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
ckpt[prefix + 'qkv.weight'] = qkv_weight
for k, v in ckpt.items():
v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
if "tok_embeddings" in k:
if not tensorrt_llm_internlm.use_parallel_embedding:
v = gather_embedding(v, k, num_ckpts)
elif tensorrt_llm_internlm.embedding_sharding_dim == 0:
# this needs a gather and then resplit along different dims
v = gather_embedding(v, k, num_ckpts)
v = split(v, mapping.tp_size, mapping.tp_rank, 0)
if mapping.is_first_pp_rank():
tensorrt_llm_internlm.vocab_embedding.weight.value = v
elif "output" in k:
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.lm_head.weight.value = v
elif k == "norm.weight":
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.ln_f.weight.value = v
else:
# layer specific weights
layer_idx = extract_layer_idx(k)
if layer_idx is None:
continue
idx = int(
layer_idx) - mapping.pp_rank * tensorrt_llm_internlm.num_layers
if idx >= tensorrt_llm_internlm.num_layers:
continue
if 'attention_norm.weight' in k:
tensorrt_llm_internlm.layers[
idx].input_layernorm.weight.value = v
elif 'ffn_norm.weight' in k:
tensorrt_llm_internlm.layers[
idx].post_layernorm.weight.value = v
elif 'feed_forward.w3.weight' in k:
tensorrt_llm_internlm.layers[idx].mlp.gate.weight.value = v
elif 'feed_forward.w2.weight' in k:
tensorrt_llm_internlm.layers[idx].mlp.proj.weight.value = v
elif 'feed_forward.w1.weight' in k:
tensorrt_llm_internlm.layers[idx].mlp.fc.weight.value = v
elif 'attention.wo.weight' in k:
tensorrt_llm_internlm.layers[
idx].attention.dense.weight.value = v
elif 'attention.qkv.weight' in k:
tensorrt_llm_internlm.layers[idx].attention.qkv.weight.value = v
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
return
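

# The Meta-style checkpoint directory is expected to hold one or more
# consolidated.{rank:02d}.pth shards; get_current_weights above gathers or
# re-splits those shards so that each TP rank ends up with exactly its slice,
# regardless of how many shards the checkpoint was saved with.

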
def load_from_binary(tensorrt_llm_internlm: InternLMForCausalLM,
dir_path,
mapping=Mapping(),
fp16=False,
multi_query_mode=False):
tensorrt_llm.logger.info('Loading weights from FT...')
tik = time.time()
quant_mode = getattr(tensorrt_llm_internlm, 'quant_mode', QuantMode(0))
n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head, attn_bias = parse_ft_config(
Path(dir_path) / 'config.ini')
np_dtype = np.float16 if fp16 else np.float32
def fromfile(dir_path, name, shape=None, dtype=None):
dtype = np_dtype if dtype is None else dtype
p = dir_path + '/' + name
print(f"Loading from {str(p)}")
if Path(p).exists():
t = np.fromfile(p, dtype=dtype)
if shape is not None:
t = t.reshape(shape)
return t
return None
def set_smoothquant_scale_factors(module,
pre_scale_weight,
dir_path,
basename,
shape,
per_tok_dyn,
per_channel,
is_qkv=False,
rank=None):
suffix = "bin"
if per_channel:
if rank is not None:
suffix = f"{rank}." + suffix
suffix = "col." + suffix
col_shape = shape if (per_channel or is_qkv) else [1, 1]
if per_tok_dyn:
if pre_scale_weight is not None:
pre_scale_weight.value = np.array([1.0], dtype=np.float32)
if is_qkv and not per_channel:
t = fromfile(dir_path,
f"{basename}scale_w_quant_orig.{rank}.{suffix}",
col_shape, np.float32)
else:
t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}",
col_shape, np.float32)
module.per_channel_scale.value = t
else:
t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1],
np.float32)
pre_scale_weight.value = t
if is_qkv:
t = fromfile(dir_path,
f"{basename}scale_y_accum_quant.{rank}.{suffix}",
col_shape, np.float32)
else:
t = fromfile(dir_path,
f"{basename}scale_y_accum_quant.{suffix}",
col_shape, np.float32)
module.per_channel_scale.value = t
t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1],
np.float32)
module.act_scale.value = t
def set_smoother(module, dir_path, base_name, shape, rank):
suffix = f"{rank}.bin"
t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape,
np.float32)
module.smoother.value = t
# Determine the quantization mode.
quant_mode = getattr(tensorrt_llm_internlm, "quant_mode", QuantMode(0))
if quant_mode.is_int8_weight_only():
plugin_weight_only_quant_type = torch.int8
elif quant_mode.is_int4_weight_only():
plugin_weight_only_quant_type = torch.quint4x2
# Do we use SmoothQuant?
use_smooth_quant = quant_mode.has_act_and_weight_quant()
# Do we use quantization per token?
quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
# Do we use quantization per channel?
quant_per_channel = quant_mode.has_per_channel_scaling()
# Do we use INT4/INT8 weight-only?
use_weight_only = quant_mode.is_weight_only()
# Int8 KV cache
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
    # Per-rank filename suffix for the weight files.
    suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel)
# The type of weights.
w_type = np_dtype if not use_smooth_quant else np.int8
if mapping.is_first_pp_rank():
tensorrt_llm_internlm.vocab_embedding.weight.value = (fromfile(
dir_path, 'vocab_embedding.weight.bin', [vocab_size, n_embd]))
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.ln_f.weight.value = (fromfile(
dir_path, 'ln_f.weight.bin'))
# share input embedding
lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin',
[vocab_size, n_embd])
if vocab_size % mapping.tp_size != 0:
# padding
vocab_size_padded = tensorrt_llm_internlm.lm_head.out_features * mapping.tp_size
pad_width = vocab_size_padded - vocab_size
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)),
'constant',
constant_values=0)
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.lm_head.weight.value = np.ascontiguousarray(
split(lm_head_weight, mapping.tp_size, mapping.tp_rank))
layers_range = list(
range(mapping.pp_rank * tensorrt_llm_internlm.num_layers,
(mapping.pp_rank + 1) * tensorrt_llm_internlm.num_layers, 1))
for i in layers_range:
n_groups = n_head // n_kv_head
c_attn_out_dim = (
3 * n_embd // mapping.tp_size) if not multi_query_mode else (
n_embd // mapping.tp_size +
(n_embd // n_head * n_groups) // mapping.tp_size * 2)
idx = i - mapping.pp_rank * tensorrt_llm_internlm.num_layers
tensorrt_llm_internlm.layers[idx].input_layernorm.weight.value = (
fromfile(dir_path,
'model.layers.' + str(i) + '.input_layernorm.weight.bin'))
t = fromfile(
dir_path, 'model.layers.' + str(i) +
'.attention.query_key_value.weight.' + suffix,
[n_embd, c_attn_out_dim], w_type)
if t is not None:
dst = tensorrt_llm_internlm.layers[idx].attention.qkv.weight
if use_smooth_quant:
dst.value = (np.ascontiguousarray(np.transpose(t, [1, 0])))
set_smoothquant_scale_factors(
tensorrt_llm_internlm.layers[idx].attention.qkv,
tensorrt_llm_internlm.layers[idx].input_layernorm.
scale_to_int,
dir_path,
'model.layers.' + str(i) + '.attention.query_key_value.',
[1, c_attn_out_dim],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank,
is_qkv=True)
elif use_weight_only:
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
# workaround for trt not supporting int8 inputs in plugins currently
dst.value = processed_torch_weights.view(
dtype=torch.float32).numpy()
                scales = tensorrt_llm_internlm.layers[
                    idx].attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dst = tensorrt_llm_internlm.layers[idx].attention.dense.weight
t = fromfile(
dir_path,
'model.layers.' + str(i) + '.attention.dense.weight.' + suffix,
[n_embd // mapping.tp_size, n_embd], w_type)
if use_smooth_quant:
dst.value = (np.ascontiguousarray(np.transpose(t, [1, 0])))
dense_scale = getattr(tensorrt_llm_internlm.layers[idx].attention,
"quantization_scaling_factor", None)
set_smoothquant_scale_factors(
tensorrt_llm_internlm.layers[idx].attention.dense, dense_scale,
dir_path, 'model.layers.' + str(i) + '.attention.dense.',
[1, n_embd], quant_per_token_dyn, quant_per_channel)
set_smoother(tensorrt_llm_internlm.layers[idx].attention.dense,
dir_path,
'model.layers.' + str(i) + '.attention.dense',
[1, n_embd // mapping.tp_size], mapping.tp_rank)
elif use_weight_only:
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_internlm.layers[
                idx].attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
if attn_bias:
dst = tensorrt_llm_internlm.layers[idx].attention.qkv.bias
t = fromfile(
dir_path, 'model.layers.' + str(i) +
f'.attention.query_key_value.bias.{mapping.tp_rank}.bin')
dst.value = np.ascontiguousarray(t)
dst = tensorrt_llm_internlm.layers[idx].attention.dense.bias
t = fromfile(dir_path,
'model.layers.' + str(i) + '.attention.dense.bias.bin')
dst.value = np.ascontiguousarray(t)
dst = tensorrt_llm_internlm.layers[idx].post_layernorm.weight
dst.value = fromfile(
dir_path, 'model.layers.' + str(i) + '.post_layernorm.weight.bin')
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.fc.weight.' + suffix,
[n_embd, inter_size // mapping.tp_size], w_type)
if use_smooth_quant:
tensorrt_llm_internlm.layers[idx].mlp.fc.weight.value = (
np.ascontiguousarray(np.transpose(t, [1, 0])))
set_smoothquant_scale_factors(
tensorrt_llm_internlm.layers[idx].mlp.fc,
tensorrt_llm_internlm.layers[idx].post_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.mlp.fc.',
[1, inter_size // mapping.tp_size],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank)
elif use_weight_only:
            dst = tensorrt_llm_internlm.layers[idx].mlp.fc.weight
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_internlm.layers[idx].mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_internlm.layers[
idx].mlp.fc.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.gate.weight.' + suffix,
[n_embd, inter_size // mapping.tp_size], w_type)
if use_smooth_quant:
tensorrt_llm_internlm.layers[idx].mlp.gate.weight.value = (
np.ascontiguousarray(np.transpose(t, [1, 0])))
set_smoothquant_scale_factors(
tensorrt_llm_internlm.layers[idx].mlp.gate,
tensorrt_llm_internlm.layers[idx].post_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.mlp.gate.',
[1, inter_size // mapping.tp_size],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank)
elif use_weight_only:
            dst = tensorrt_llm_internlm.layers[idx].mlp.gate.weight
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_internlm.layers[idx].mlp.gate.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_internlm.layers[
idx].mlp.gate.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.proj.weight.' + suffix,
[inter_size // mapping.tp_size, n_embd], w_type)
if use_smooth_quant:
tensorrt_llm_internlm.layers[idx].mlp.proj.weight.value = (
np.ascontiguousarray(np.transpose(t, [1, 0])))
proj_scale = getattr(tensorrt_llm_internlm.layers[idx].mlp,
"quantization_scaling_factor", None)
set_smoothquant_scale_factors(
tensorrt_llm_internlm.layers[idx].mlp.proj, proj_scale,
dir_path, 'model.layers.' + str(i) + '.mlp.proj.', [1, n_embd],
quant_per_token_dyn, quant_per_channel)
set_smoother(tensorrt_llm_internlm.layers[idx].mlp.proj, dir_path,
'model.layers.' + str(i) + '.mlp.proj',
[1, inter_size // mapping.tp_size], mapping.tp_rank)
elif use_weight_only:
            dst = tensorrt_llm_internlm.layers[idx].mlp.proj.weight
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
            scales = tensorrt_llm_internlm.layers[idx].mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_internlm.layers[idx].mlp.proj.weight.value = (
np.ascontiguousarray(np.transpose(t, [1, 0])))
if use_int8_kv_cache:
t = fromfile(
dir_path, 'model.layers.' + str(i) +
'.attention.query_key_value.scale_y_quant_orig.bin', [1],
np.float32)
tensorrt_llm_internlm.layers[
idx].attention.kv_orig_quant_scale.value = 1.0 / t
tensorrt_llm_internlm.layers[
idx].attention.kv_quant_orig_scale.value = t
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
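

# load_from_binary expects an FT-style dump: a config.ini (see parse_ft_config)
# next to per-tensor .bin files such as 'vocab_embedding.weight.bin',
# 'lm_head.weight.bin' and 'model.layers.<i>.*.bin', with per-rank suffixes as
# produced by gen_suffix for SmoothQuant / per-channel checkpoints.

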
def load_from_gptq_internlm(tensorrt_llm_internlm,
quant_ckpt_path,
mapping=Mapping(),
dtype="float16"):
tensorrt_llm.logger.info(
'Loading weights from groupwise GPTQ InternLM safetensors...')
tik = time.time()
if quant_ckpt_path.endswith(".safetensors"):
groupwise_qweight_safetensors = safe_open(quant_ckpt_path,
framework="pt",
device=0)
model_params = {
key: groupwise_qweight_safetensors.get_tensor(key)
for key in groupwise_qweight_safetensors.keys()
}
elif quant_ckpt_path.endswith(".pt"):
model_params = torch.load(quant_ckpt_path,
map_location=torch.device('cpu'))
else:
assert False, "Quantized checkpoint format not supported!"
def unpack_int32_into_int8(w_packed):
# Unpack inputs packed in int32/float32 into uint4 and store them in int8 format
w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
w_packed_int4x2.shape[1] * 2,
dtype=torch.int8)
w_unpacked[:, ::2] = w_packed_int4x2 % 16
w_unpacked[:, 1::2] = w_packed_int4x2 // 16
return w_unpacked.contiguous()
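
    # Each int32 element packs eight uint4 values; for example a [m, n] int32
    # qweight tensor unpacks to a [m, 8 * n] int8 tensor, with the low nibble
    # of every byte stored before the high nibble.
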
def preprocess_groupwise_weight_params(weight_name,
qweight_int32=None,
qzeros_int32=None,
scales_fp16=None):
if weight_name is not None:
qweight_int32 = model_params[weight_name].cpu()
qzeros_int32 = model_params[weight_name[:-7] + 'qzeros'].cpu()
scales_fp16 = model_params[weight_name[:-7] + 'scales'].cpu()
UINT4_TO_INT4_FLAG = 1
GPTQ_FLAG = 1
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
qweight_unpacked_int8 = unpack_int32_into_int8(
qweight_int32.T).T.contiguous() - 8
qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
torch.quint4x2).view(torch.float32)
# zeros = zeros * scales
qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32)
zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * UINT4_TO_INT4_FLAG -
GPTQ_FLAG) * scales_fp16
zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
# return processed interleaved weight, original scales and zeros * scales
return qweight_interleaved.contiguous(), scales_fp16.contiguous(
), zeros_x_scales_fp16.contiguous()
    layer_ids = [extract_layer_idx(key) for key in model_params.keys()]
layer_ids = [
int(layer_idx) for layer_idx in layer_ids if layer_idx is not None
]
num_hidden_layers = max(layer_ids) + 1
num_kv_heads = tensorrt_llm_internlm.num_kv_heads
mha_mode = (num_kv_heads == tensorrt_llm_internlm.num_heads)
suffixs = ['qweight', 'qzeros', 'scales']
layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
for l in layers_range:
prefix = f'model.layers.{l}.self_attn.'
split_qkv_suf = []
for suf in suffixs:
q_part = model_params[prefix + 'q_proj.' + suf].cpu()
k_part = model_params[prefix + 'k_proj.' + suf].cpu()
v_part = model_params[prefix + 'v_proj.' + suf].cpu()
qkv_part = torch.cat([q_part, k_part, v_part], dim=0)
dim = qkv_part.shape
qkv_part = qkv_part.reshape(3, dim[0] // 3, dim[1])
split_qkv = qkv_part.split(dim[1] // mapping.tp_size,
dim=2)[mapping.tp_rank]
split_qkv = torch.cat([
split_qkv[0, :, :].squeeze(0), split_qkv[1, :, :].squeeze(0),
split_qkv[2, :, :].squeeze(0)
],
dim=1)
split_qkv_suf.append(split_qkv)
th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
None, split_qkv_suf[0], split_qkv_suf[1], split_qkv_suf[2])
idx = l - mapping.pp_rank * layers_per_pipeline_stage
tensorrt_llm_internlm.layers[
idx].attention.qkv.qweight.value = th_qweight.numpy()
tensorrt_llm_internlm.layers[
idx].attention.qkv.scale.value = th_zero.numpy()
tensorrt_llm_internlm.layers[
idx].attention.qkv.zero.value = th_scale.numpy()
torch_dtype = str_dtype_to_torch(dtype)
for k, v in model_params.items():
if isinstance(v, list):
v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v]
else:
v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
if 'model.embed_tokens.weight' in k:
if mapping.is_first_pp_rank():
tensorrt_llm_internlm.vocab_embedding.weight.value = v
elif 'model.norm.weight' in k:
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.ln_f.weight.value = v
elif 'lm_head.weight' in k:
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.lm_head.weight.value = np.ascontiguousarray(
split(v, mapping.tp_size, mapping.tp_rank))
else:
layer_idx = extract_layer_idx(k)
if layer_idx is None:
continue
idx = int(layer_idx)
if idx not in layers_range:
continue
idx = idx - mapping.pp_rank * layers_per_pipeline_stage
if 'input_layernorm.weight' in k:
tensorrt_llm_internlm.layers[
idx].input_layernorm.weight.value = v
elif 'post_attention_layernorm.weight' in k:
tensorrt_llm_internlm.layers[
idx].post_layernorm.weight.value = v
elif 'self_attn.o_proj.qweight' in k:
split_v_suf = []
for suf in suffixs:
v = model_params[k[:-7] + suf].cpu()
split_v = v.split(v.shape[0] // mapping.tp_size,
dim=0)[mapping.tp_rank]
split_v_suf.append(split_v)
th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
None, split_v_suf[0], split_v_suf[1], split_v_suf[2])
tensorrt_llm_internlm.layers[
idx].attention.dense.qweight.value = th_qweight.numpy()
tensorrt_llm_internlm.layers[
idx].attention.dense.scale.value = th_zero.numpy()
tensorrt_llm_internlm.layers[
idx].attention.dense.zero.value = th_scale.numpy()
elif 'mlp.up_proj.qweight' in k:
split_v_suf = []
for suf in suffixs:
v = model_params[k[:-7] + suf].cpu()
split_v = v.split(v.shape[1] // mapping.tp_size,
dim=1)[mapping.tp_rank]
split_v_suf.append(split_v)
th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
None, split_v_suf[0], split_v_suf[1], split_v_suf[2])
tensorrt_llm_internlm.layers[
idx].mlp.gate.qweight.value = th_qweight.numpy()
tensorrt_llm_internlm.layers[
idx].mlp.gate.scale.value = th_zero.numpy()
tensorrt_llm_internlm.layers[
idx].mlp.gate.zero.value = th_scale.numpy()
elif 'mlp.down_proj.qweight' in k:
split_v_suf = []
for suf in suffixs:
v = model_params[k[:-7] + suf].cpu()
split_v = v.split(v.shape[0] // mapping.tp_size,
dim=0)[mapping.tp_rank]
split_v_suf.append(split_v)
th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
None, split_v_suf[0], split_v_suf[1], split_v_suf[2])
tensorrt_llm_internlm.layers[
idx].mlp.proj.qweight.value = th_qweight.numpy()
tensorrt_llm_internlm.layers[
idx].mlp.proj.scale.value = th_zero.numpy()
tensorrt_llm_internlm.layers[
idx].mlp.proj.zero.value = th_scale.numpy()
elif 'mlp.gate_proj.qweight' in k:
split_v_suf = []
for suf in suffixs:
v = model_params[k[:-7] + suf].cpu()
split_v = v.split(v.shape[1] // mapping.tp_size,
dim=1)[mapping.tp_rank]
split_v_suf.append(split_v)
th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
None, split_v_suf[0], split_v_suf[1], split_v_suf[2])
tensorrt_llm_internlm.layers[
idx].mlp.fc.qweight.value = th_qweight.numpy()
tensorrt_llm_internlm.layers[
idx].mlp.fc.scale.value = th_zero.numpy()
tensorrt_llm_internlm.layers[
idx].mlp.fc.zero.value = th_scale.numpy()
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
return
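

# The GPTQ checkpoint is expected to store, for every quantized linear layer,
# the triplet '<name>.qweight', '<name>.qzeros' and '<name>.scales', with the
# 4-bit values packed into int32 tensors as unpacked by unpack_int32_into_int8
# above.

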
def load_from_awq_internlm(tensorrt_llm_internlm: InternLMForCausalLM,
quant_ckpt_path,
mapping=Mapping(),
dtype="float16"):
tensorrt_llm.logger.info(
'Loading weights from groupwise AWQ InternLM safetensors...')
tik = time.time()
if quant_ckpt_path.endswith(".safetensors"):
groupwise_qweight_safetensors = safe_open(quant_ckpt_path,
framework="pt",
device=0)
awq_internlm = {
key: groupwise_qweight_safetensors.get_tensor(key)
for key in groupwise_qweight_safetensors.keys()
}
elif quant_ckpt_path.endswith(".pt"):
awq_internlm = torch.load(quant_ckpt_path,
map_location=torch.device('cpu'))
else:
assert False, "Quantized checkpoint format not supported!"
group_size = awq_internlm["model.layers.0.self_attn.o_proj.weight"].numel(
) // awq_internlm[
"model.layers.0.self_attn.o_proj.weight_quantizer._amax"].numel()
awq_internlm_block_names = [
"input_layernorm.weight",
"post_attention_layernorm.weight",
]
tensorrt_llm_internlm_block_names = [
"input_layernorm.weight",
"post_layernorm.weight",
]
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
torch_dtype = str_dtype_to_torch(dtype)
def AWQ_quantize_pack_preprocess(weight, scale):
scale = scale.repeat_interleave(group_size, dim=0)
weight = weight / scale
qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7)
int4_weight = packer(qweight_int8.cpu())
int4_weight = preprocessor(int4_weight, torch.quint4x2)
return int4_weight.view(torch.float32).cpu().numpy()
def process_and_assign_weight(awq_internlm, mPrefix, mOp, tp_dim=0):
weight = awq_internlm[mPrefix + ".weight"].T.contiguous()
[k, n] = weight.shape
weight = weight.split(weight.shape[tp_dim] // mapping.tp_size,
dim=tp_dim)[mapping.tp_rank]
amax = awq_internlm[mPrefix + ".weight_quantizer._amax"].reshape(
(n, int(k / group_size))).T.contiguous()
amax = amax.split(amax.shape[tp_dim] // mapping.tp_size,
dim=tp_dim)[mapping.tp_rank]
pre_quant_scale = awq_internlm[
mPrefix + ".input_quantizer._pre_quant_scale"].reshape((1, k))
if tp_dim == 0:
pre_quant_scale = pre_quant_scale.split(k // mapping.tp_size,
dim=1)[mapping.tp_rank]
scale = amax / 8.0
mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale)
mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
mOp.pre_quant_scale.value = pre_quant_scale.to(
torch_dtype).cpu().numpy()
def deSmooth(weight, pre_quant_scale):
[k, n] = weight.shape
pre_quant_scale = pre_quant_scale.repeat(
(n, 1)).transpose(1, 0).contiguous()
weight = weight * pre_quant_scale
return weight
def reSmooth(weight, pre_quant_scale):
[k, n] = weight.shape
pre_quant_scale = pre_quant_scale.repeat(
(n, 1)).transpose(1, 0).contiguous()
weight = weight / pre_quant_scale
return weight
def get_scale(weight):
weight = weight.T.contiguous()
[n, k] = weight.shape
weight = weight.reshape(n, int(k / group_size), group_size)
weight = torch.abs(weight.reshape(-1, group_size))
amax, idx = weight.max(1)
amax = amax.reshape(n, int(k / group_size)).T.contiguous()
return amax / 8
def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale):
weight = deSmooth(weight, pre_quant_scale)
weight = reSmooth(weight, avg_pre_quant_scale)
scale = get_scale(weight)
return weight, scale
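
    # Q, K and V are exported with their own per-channel pre-quant scales, but
    # they share a single fused GEMM here, so each weight is first un-smoothed
    # with its own scale, re-smoothed with the averaged QKV scale, and the
    # group-wise amax / 8 quantization scales are then recomputed.
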
def process_and_assign_qkv_weight(awq_internlm, prefix, mOp):
q_weight = awq_internlm[prefix +
"self_attn.q_proj.weight"].T.contiguous()
k_weight = awq_internlm[prefix +
"self_attn.k_proj.weight"].T.contiguous()
v_weight = awq_internlm[prefix +
"self_attn.v_proj.weight"].T.contiguous()
k = q_weight.shape[0]
q_weight = q_weight.split(q_weight.shape[1] // mapping.tp_size,
dim=1)[mapping.tp_rank]
k_weight = k_weight.split(k_weight.shape[1] // mapping.tp_size,
dim=1)[mapping.tp_rank]
v_weight = v_weight.split(v_weight.shape[1] // mapping.tp_size,
dim=1)[mapping.tp_rank]
q_pre_quant_scale = awq_internlm[
prefix +
"self_attn.q_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
k_pre_quant_scale = awq_internlm[
prefix +
"self_attn.k_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
v_pre_quant_scale = awq_internlm[
prefix +
"self_attn.v_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale +
v_pre_quant_scale) / 3.0
q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale,
qkv_pre_quant_scale)
k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale,
qkv_pre_quant_scale)
v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale,
qkv_pre_quant_scale)
qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)
mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(
torch_dtype).cpu().numpy()
mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale)
mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()
# Check if we need to pad vocab
v = awq_internlm.get('model.embed_tokens.weight')
[vocab_size, k] = v.shape
pad_vocab = False
pad_vocab_size = vocab_size
if vocab_size % 64 != 0:
pad_vocab = True
pad_vocab_size = int((vocab_size + 63) / 64) * 64
if pad_vocab:
new_v = torch.zeros([pad_vocab_size, k])
new_v[:vocab_size, :] = v
v = new_v
if mapping.is_first_pp_rank():
tensorrt_llm_internlm.vocab_embedding.weight.value = v.to(
torch_dtype).cpu().numpy()
layer_ids = [extract_layer_idx(key) for key in awq_internlm.keys()]
layer_ids = [
int(layer_idx) for layer_idx in layer_ids if layer_idx is not None
]
num_hidden_layers = max(layer_ids) + 1
layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
for layer_idx in layers_range:
prefix = "model.layers." + str(layer_idx) + "."
tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
for idx, awq_attr in enumerate(awq_internlm_block_names):
v = awq_internlm[prefix + awq_attr]
layer = attrgetter(tensorrt_llm_internlm_block_names[idx])(
tensorrt_llm_internlm.layers[layer_idx])
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
# Attention QKV Linear
# concatenate the Q, K, V layers weights.
process_and_assign_qkv_weight(
awq_internlm, prefix,
tensorrt_llm_internlm.layers[layer_idx].attention.qkv)
# Attention Dense (out_proj) Linear
mPrefix = prefix + "self_attn.o_proj"
mOp = tensorrt_llm_internlm.layers[layer_idx].attention.dense
process_and_assign_weight(awq_internlm, mPrefix, mOp, 0)
# MLP up_proj (mlp.gate) Linear
mPrefix = prefix + "mlp.up_proj"
mOp = tensorrt_llm_internlm.layers[layer_idx].mlp.gate
process_and_assign_weight(awq_internlm, mPrefix, mOp, 1)
# MLP down_proj (mlp.proj) Linear
mPrefix = prefix + "mlp.down_proj"
mOp = tensorrt_llm_internlm.layers[layer_idx].mlp.proj
process_and_assign_weight(awq_internlm, mPrefix, mOp, 0)
# MLP gate_proj (mlp.fc) Linear
mPrefix = prefix + "mlp.gate_proj"
mOp = tensorrt_llm_internlm.layers[layer_idx].mlp.fc
process_and_assign_weight(awq_internlm, mPrefix, mOp, 1)
v = awq_internlm['model.norm.weight']
if mapping.is_last_pp_rank():
tensorrt_llm_internlm.ln_f.weight.value = v.to(
torch_dtype).cpu().numpy()
#lm_head
if pad_vocab:
weight = awq_internlm['lm_head.weight']
[vocab_size, k] = weight.shape
new_weight = torch.zeros([pad_vocab_size, k])
new_weight[:vocab_size, :] = weight
new_weight = new_weight.T.contiguous()
amax = awq_internlm['lm_head.weight_quantizer._amax'].reshape(
[vocab_size, k // group_size])
new_amax = torch.ones([pad_vocab_size, k // group_size])
new_amax[:vocab_size, :] = amax
new_amax = new_amax.T.contiguous()
new_scale = new_amax / 8
tensorrt_llm_internlm.lm_head.qweight.value = AWQ_quantize_pack_preprocess(
new_weight, new_scale)
tensorrt_llm_internlm.lm_head.scale.value = new_scale.to(
torch_dtype).cpu().numpy()
tensorrt_llm_internlm.lm_head.pre_quant_scale.value = awq_internlm[
'lm_head.input_quantizer._pre_quant_scale'].to(
torch_dtype).cpu().numpy()
else:
mPrefix = "lm_head"
mOp = tensorrt_llm_internlm.lm_head
if mapping.is_last_pp_rank():
process_and_assign_weight(awq_internlm, mPrefix, mOp, 1)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')