# TensorRT-LLM: examples/falcon/convert_checkpoint.py
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Dict, List, Optional, Tuple
import numpy as np
import safetensors
import torch
from transformers import AutoModelForCausalLM, FalconConfig, FalconForCausalLM
import tensorrt_llm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.llama.utils import ( # TODO: move the utils to common dir shared by models
iterate_shard_files, load_state_dict, retrieved_layer_index_from_name)


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--tp_size', type=int, default=1)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float32', 'bfloat16', 'float16'])
parser.add_argument('--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'])
parser.add_argument(
'--ammo_quant_ckpt_path',
type=str,
default=None,
help='Path of a quantized model checkpoint in .npz format')
parser.add_argument(
'--per_group',
default=False,
action="store_true",
        help=
        'By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group chooses a custom scaling factor for each group at runtime. '
        'The flag is intended for GPTQ/AWQ quantization.')
parser.add_argument('--group_size',
type=int,
default=128,
help='Group size used in GPTQ/AWQ quantization.')
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_awq'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for this argument to have an effect.'
    )
parser.add_argument(
'--quantize_lm_head',
default=False,
action="store_true",
help='Quantize lm_head weights as well when using int4_awq.')
parser.add_argument(
'--enable_fp8',
default=False,
action='store_true',
help='Use FP8 Linear layer for Attention QKV/Dense and MLP.')
parser.add_argument(
'--fp8_kv_cache',
default=False,
action="store_true",
        help='By default, the KV cache is stored in the model dtype. '
        'fp8_kv_cache chooses FP8 quantization for the KV cache instead.')
parser.add_argument('--load_by_shard',
action='store_true',
help='Load a pretrained model shard-by-shard.')
parser.add_argument('--output_dir',
type=str,
default='tllm_checkpoint',
help='The path to save the TensorRT-LLM checkpoint')
parser.add_argument(
'--workers',
type=int,
default=1,
help='The number of workers for converting checkpoint in parallel')
args = parser.parse_args()
return args
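

# Example invocation (a minimal sketch; the paths below are hypothetical and
# depend on the local setup):
#
#   python convert_checkpoint.py --model_dir ./falcon-7b \
#       --dtype float16 \
#       --output_dir ./tllm_checkpoint_falcon_1gpu
#
# For tensor/pipeline parallelism, pass --tp_size/--pp_size and set
# --world_size to tp_size * pp_size; the script writes config.json plus one
# rank<N>.safetensors file per rank into --output_dir.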


def load_falcon_config(model_dir: str) -> FalconConfig:
""" Helper utility to load FalconConfig.
A pretrained checkpoint from modeling_RW.py has a different structure
and is not compatible with `transformers.FalconConfig` and
`transformers.FalconModel`. We need to manually set the config values.
"""
config = FalconConfig.from_pretrained(model_dir)
config.architectures = ["FalconForCausalLM"]
if config.model_type not in ['RefinedWebModel', 'RefinedWeb']:
return config
if config.model_type == 'RefinedWeb':
# Case 1. Falcon-40B / Falcon-40B-instruct
# https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
config.num_hidden_layers = config.n_layer
config.num_attention_heads = config.n_head
config.num_kv_heads = config.n_head_kv
config.new_decoder_architecture = True
elif config.model_type == 'RefinedWebModel':
# Case 2. Falcon-7B / Falcon-7B-instruct
# https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
config.num_hidden_layers = config.n_layer
config.num_attention_heads = config.n_head
config.num_kv_heads = 1 if config.multi_query else config.n_head
config.new_decoder_architecture = False
else:
raise ValueError("Shouldn't reach here.")
config.model_type = 'falcon'
return config
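
# Example of the remapping above (values taken from the published
# tiiuae/falcon-40b config, shown only for illustration): a legacy RefinedWeb
# config with n_layer=60, n_head=128 and n_head_kv=8 becomes
# num_hidden_layers=60, num_attention_heads=128, num_kv_heads=8 with
# new_decoder_architecture=True.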


def split(weight: torch.Tensor,
tp_size: int,
rank: int = 0,
dim: int = 0) -> torch.Tensor:
if tp_size == 1:
return weight
elif weight.ndim == 1:
return torch.chunk(weight, tp_size)[rank].contiguous()
else:
return torch.chunk(weight, tp_size, dim=dim)[rank].contiguous()


def reorder_qkv_weight_or_bias(weight: torch.Tensor,
head_dim: int,
num_heads: int,
num_kv_heads: Optional[int] = None,
tp_size: int = 1,
is_bias: bool = False) -> torch.Tensor:
""" Reorder the qkv weight for TRT-LLM use.
The shape of the fused QKV weights in HF is different from the shape that
TRT-LLM requires. In particular, the weight of HF consists of interleaved
q, k, v head weights, while that of TRT-LLM is contiguous.
HF : [q1, k1, v1, ..., qh, kh, vh]
    TRT-LLM: [q1, ..., qh, k1, ..., kh, v1, ..., vh]
where qi, vi, ki are weight vectors corresponding to attention head i.
It's similar to multi/grouped query attention cases.
We reorder and split the weight of an attention layer to fit into TRT-LLM.
The reordered weight and bias will be
weight: (T, Qh * D + 2 * KVh * D, H)
bias : (T, Qh * D + 2 * KVh * D)
where T=tp_size, Qh=local_num_q_heads, KVh=local_num_kv_heads, D=head_dim,
    H=hidden_dim. In multi/grouped-query attention, the number of K/V heads
    is smaller than the number of Q heads, so K/V heads may be shared across
    ranks if necessary.
For tensor parallelism, we use the first dimension to select the
corresponding weights.
"""
# Query types and expected kv heads.
# - Conventional MHA: num_heads = num_kv_heads
# - Multi-Query Attention: num_kv_heads = 1
# - Grouped-Query Attention: num_heads % num_kv_heads = 0
num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
assert num_heads % num_kv_heads == 0, \
f'num_heads({num_heads}) must be divisible by '\
        f'num_kv_heads({num_kv_heads}).'
# The number of attention heads per group: N q head + 1 k head + 1 v head.
num_group_heads = num_heads // num_kv_heads + 2
assert weight.shape[0] == num_kv_heads * num_group_heads * head_dim, \
f'{weight.shape[0]} != {num_kv_heads} * {num_group_heads} * {head_dim}'
qkv_in = num_heads * head_dim if not is_bias else 1
# Split Q/K/V weights
weight = weight.reshape(num_kv_heads, num_heads // num_kv_heads + 2,
head_dim, qkv_in)
q_w = weight[:, :-2, ...] # (nKV, num_heads // nKV, head_dim, qkv_in)
k_w = weight[:, -2:-1, ...] # (nKV, 1, head_dim, qkv_in)
v_w = weight[:, -1:, ...] # (nKV, 1, head_dim, qkv_in)
if num_kv_heads < num_heads and num_kv_heads < tp_size:
# Duplicate K/V heads to make sure that each rank has at least one
# K/V heads. For instance, num_heads=8, num_kv_heads=2, tp_size=4,
# we will make the qkv weight as below.
        # Orig: [q0 q1 q2 q3 k0 v0 q4 q5 q6 q7 k1 v1]
# >>>> [[q0 q1 k0 v0], [q2 q3 k0 v0], [q4 q5 k1 v1], [q6 q7 k1 v1]]
assert tp_size % num_kv_heads == 0
num_dups = tp_size // num_kv_heads
# k_w and v_w have the same shape.
new_shape = (num_kv_heads, num_dups) + k_w.shape[2:]
k_w = torch.broadcast_to(k_w, size=new_shape)
v_w = torch.broadcast_to(v_w, size=new_shape)
# Update the number of kv heads.
num_kv_heads = tp_size
reordered = torch.concat(
[
q_w.reshape(tp_size, num_heads // tp_size, head_dim, qkv_in),
k_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
v_w.reshape(tp_size, num_kv_heads // tp_size, head_dim, qkv_in),
],
dim=1,
)
qkv_out = (num_heads + 2 * num_kv_heads) // tp_size * head_dim
return reordered.reshape((tp_size, qkv_out, -1))
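
# Shape example (illustrative, single-rank MQA): with hidden_size=4544,
# num_heads=71, head_dim=64 and num_kv_heads=1 (Falcon-7B-style values), the
# fused HF weight of shape (73 * 64, 4544) = (4672, 4544) is reordered to
# (tp_size, (71 + 2 * 1) * 64, 4544) = (1, 4672, 4544).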


def split_qkv_weight(weight: torch.Tensor,
hidden_size: int,
num_heads: int,
tp_size: int,
rank: int,
is_bias: bool,
num_kv_heads: Optional[int] = None) -> torch.Tensor:
""" Splits the QKV matrix according to tensor parallelism """
head_dim = hidden_size // num_heads
weight = reorder_qkv_weight_or_bias(weight,
head_dim=head_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
tp_size=tp_size,
is_bias=is_bias)
    # Copy the sliced tensor to prevent a memory leak. A sliced tensor shares
    # the memory buffer of the original tensor, so returning it without copying
    # keeps the whole loaded "qkv" buffer referenced and the GC cannot release
    # those weights until the process ends.
if not is_bias:
return weight[rank, ...].clone().contiguous()
else:
return weight[rank, ...].ravel().clone().contiguous()


def split_matrix(weight: torch.Tensor, tp_size: int, rank: int,
dim: int) -> torch.Tensor:
return split(weight, tp_size, rank, dim=dim)


def get_weight(params: Dict[str, torch.Tensor], prefix: str,
               dtype: torch.dtype) -> Optional[torch.Tensor]:
if f'{prefix}.weight' not in params:
return None
return params[f'{prefix}.weight'].to(dtype).detach().cpu()


def get_bias(params: Dict[str, torch.Tensor], prefix: str,
             dtype: torch.dtype) -> Optional[torch.Tensor]:
if f'{prefix}.bias' not in params:
return None
return params[f'{prefix}.bias'].to(dtype).detach().cpu()


def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str,
                        dtype: torch.dtype
                        ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype)


def get_tllm_linear_weight(
weight: torch.Tensor,
prefix: str,
bias: Optional[torch.Tensor] = None,
use_weight_only: bool = False,
plugin_weight_only_quant_type: torch.dtype = torch.int8
) -> Dict[str, torch.Tensor]:
results = {}
if use_weight_only:
v = weight.t().contiguous()
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
v, plugin_weight_only_quant_type)
results[f'{prefix}.weight'] = processed_torch_weights
results[f'{prefix}.per_channel_scale'] = torch_weight_scales
else:
results[f'{prefix}.weight'] = weight.contiguous()
if bias is not None:
results[f'{prefix}.bias'] = bias
return results


def get_tllm_param(
param: torch.Tensor,
name: str,
use_weight_only: bool = False,
plugin_weight_only_quant_type: torch.dtype = torch.int8
) -> Dict[str, torch.Tensor]:
results = {}
if name.endswith('.weight') and use_weight_only:
v = param.t().contiguous()
processed_torch_weights, torch_weight_scales = \
torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
v, plugin_weight_only_quant_type)
results[name] = processed_torch_weights
results[name.replace('weight',
'per_channel_scale')] = torch_weight_scales
else:
results[name] = param
return results
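
# With use_weight_only, both helpers above pass the transposed (in, out) weight
# to the fastertransformer symmetric-quantization op and store its two outputs
# as '<prefix>.weight' (packed INT8/INT4 data) and '<prefix>.per_channel_scale'
# (one scale per output channel); the packed layout itself is an implementation
# detail of the weight-only GEMM plugin and is not interpreted here.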


def convert_hf_falcon(hf_model: FalconForCausalLM,
hf_config: FalconConfig,
mapping: Mapping,
dtype: str = 'float32',
use_weight_only: bool = False,
plugin_weight_only_quant_type: torch.dtype = torch.int8):
weights = {}
tik = time.time()
model_params = dict(hf_model.named_parameters())
dtype = getattr(torch, dtype)
num_attention_heads = hf_config.num_attention_heads
hidden_size = hf_config.hidden_size
num_kv_heads = getattr(hf_config, 'num_kv_heads', num_attention_heads)
num_hidden_layers = hf_config.num_hidden_layers
parallel_attention = hf_config.parallel_attn
new_decoder_architecture = hf_config.new_decoder_architecture
layers_range = mapping.pp_layers(num_hidden_layers)
for l in layers_range:
prefix = f'transformer.h.{l}'
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
qkv_weight, qkv_bias = get_weight_and_bias(
model_params, f'{prefix}.self_attention.query_key_value', dtype)
qkv_w = split_qkv_weight(qkv_weight,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=False,
num_kv_heads=num_kv_heads)
if qkv_bias is None:
qkv_b = None
else:
qkv_b = split_qkv_weight(qkv_bias,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=True,
num_kv_heads=num_kv_heads)
weights.update(
get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', qkv_b,
use_weight_only,
plugin_weight_only_quant_type))
attn_dense_weight, attn_dense_bias = get_weight_and_bias(
model_params, f'{prefix}.self_attention.dense', dtype)
attn_dense_w = split_matrix(attn_dense_weight,
mapping.tp_size,
mapping.tp_rank,
dim=1)
weights.update(
get_tllm_linear_weight(attn_dense_w, f'{tllm_prex}.attention.dense',
attn_dense_bias, use_weight_only,
plugin_weight_only_quant_type))
mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
model_params, f'{prefix}.mlp.dense_h_to_4h', dtype)
mlp_fc_w = split_matrix(mlp_fc_weight,
mapping.tp_size,
mapping.tp_rank,
dim=0)
if mlp_fc_bias is None:
mlp_fc_b = None
else:
mlp_fc_b = split_matrix(mlp_fc_bias,
mapping.tp_size,
mapping.tp_rank,
dim=0)
weights.update(
get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', mlp_fc_b,
use_weight_only,
plugin_weight_only_quant_type))
mlp_proj_weight, mlp_proj_bias = get_weight_and_bias(
model_params, f'{prefix}.mlp.dense_4h_to_h', dtype)
mlp_proj_w = split_matrix(mlp_proj_weight,
mapping.tp_size,
mapping.tp_rank,
dim=1)
weights.update(
get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj',
mlp_proj_bias, use_weight_only,
plugin_weight_only_quant_type))
if new_decoder_architecture:
input_ln_weight, input_ln_bias = get_weight_and_bias(
model_params, f'{prefix}.ln_attn', dtype)
weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_weight
if input_ln_bias is not None:
weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_bias
mlp_ln_weight, mlp_ln_bias = get_weight_and_bias(
model_params, f'{prefix}.ln_mlp', dtype)
weights[f'{tllm_prex}.mlp_layernorm.weight'] = mlp_ln_weight
if mlp_ln_bias is not None:
weights[f'{tllm_prex}.mlp_layernorm.bias'] = mlp_ln_bias
else:
input_ln_weight, input_ln_bias = get_weight_and_bias(
model_params, f'{prefix}.input_layernorm', dtype)
weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_weight
if input_ln_bias is not None:
weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_bias
if not parallel_attention:
post_ln_weight, post_ln_bias = get_weight_and_bias(
model_params, f'{prefix}.post_attention_layernorm', dtype)
if post_ln_weight is not None:
weights[
f'{tllm_prex}.post_layernorm.weight'] = post_ln_weight
if post_ln_bias is not None:
weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_bias
embed_w = get_weight(model_params, 'transformer.word_embeddings', dtype)
if mapping.is_first_pp_rank():
weights['transformer.vocab_embedding.weight'] = embed_w
if mapping.is_last_pp_rank():
weights['lm_head.weight'] = split_matrix(embed_w.clone(),
mapping.tp_size,
mapping.tp_rank,
dim=0)
ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f',
dtype)
weights['transformer.ln_f.weight'] = ln_f_w
if ln_f_b is not None:
weights['transformer.ln_f.bias'] = ln_f_b
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights
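
# The returned dict maps TensorRT-LLM checkpoint tensor names to tensors, e.g.
# 'transformer.vocab_embedding.weight',
# 'transformer.layers.0.attention.qkv.weight',
# 'transformer.layers.0.mlp.fc.weight', 'transformer.ln_f.weight' and
# 'lm_head.weight' (key names taken from the assignments above).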


def load_from_hf_falcon_checkpoint(
hf_model_dir: str,
hf_config: FalconConfig,
mapping: Mapping,
dtype: str = 'float32',
use_weight_only: bool = False,
plugin_weight_only_quant_type: torch.dtype = torch.int8):
weights = {}
tik = time.time()
dtype = getattr(torch, dtype)
num_attention_heads = hf_config.num_attention_heads
hidden_size = hf_config.hidden_size
num_kv_heads = getattr(hf_config, 'num_kv_heads', num_attention_heads)
num_hidden_layers = hf_config.num_hidden_layers
layers_range = mapping.pp_layers(num_hidden_layers)
for model_file in iterate_shard_files(hf_model_dir, mapping.tp_rank):
state_dict = load_state_dict(model_file, dtype)
for name, param in state_dict.items():
param = param.detach().cpu()
l = retrieved_layer_index_from_name(name)
if l is not None:
if l not in layers_range:
continue
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
if 'self_attention.query_key_value' in name:
if name.endswith('weight'):
qkv_w = split_qkv_weight(param,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=False,
num_kv_heads=num_kv_heads)
weights.update(
get_tllm_param(qkv_w,
f'{tllm_prex}.attention.qkv.weight',
use_weight_only,
plugin_weight_only_quant_type))
else:
qkv_b = split_qkv_weight(param,
hidden_size,
num_attention_heads,
mapping.tp_size,
mapping.tp_rank,
is_bias=True,
num_kv_heads=num_kv_heads)
weights.update(
get_tllm_param(qkv_b,
f'{tllm_prex}.attention.qkv.bias',
use_weight_only,
plugin_weight_only_quant_type))
elif 'self_attention.dense' in name:
if name.endswith('weight'):
attn_dense_w = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=1)
weights.update(
get_tllm_param(
attn_dense_w,
f'{tllm_prex}.attention.dense.weight',
use_weight_only, plugin_weight_only_quant_type))
else:
weights.update(
get_tllm_param(param,
f'{tllm_prex}.attention.dense.bias',
use_weight_only,
plugin_weight_only_quant_type))
elif 'mlp.dense_h_to_4h' in name:
if name.endswith('weight'):
mlp_fc_w = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=0)
weights.update(
get_tllm_param(mlp_fc_w,
f'{tllm_prex}.mlp.fc.weight',
use_weight_only,
plugin_weight_only_quant_type))
else:
mlp_fc_b = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=0)
weights.update(
get_tllm_param(mlp_fc_b, f'{tllm_prex}.mlp.fc.bias',
use_weight_only,
plugin_weight_only_quant_type))
elif 'mlp.dense_4h_to_h' in name:
if name.endswith('weight'):
mlp_proj_w = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=1)
weights.update(
get_tllm_param(mlp_proj_w,
f'{tllm_prex}.mlp.proj.weight',
use_weight_only,
plugin_weight_only_quant_type))
else:
weights.update(
get_tllm_param(param, f'{tllm_prex}.mlp.proj.bias',
use_weight_only,
plugin_weight_only_quant_type))
elif 'ln_attn' in name or 'input_layernorm' in name:
if name.endswith('weight'):
weights[f'{tllm_prex}.input_layernorm.weight'] = param
else:
weights[f'{tllm_prex}.input_layernorm.bias'] = param
elif 'ln_mlp' in name:
if name.endswith('weight'):
weights[f'{tllm_prex}.mlp_layernorm.weight'] = param
else:
weights[f'{tllm_prex}.mlp_layernorm.bias'] = param
elif 'post_attention_layernorm' in name:
if name.endswith('weight'):
weights[f'{tllm_prex}.post_layernorm.weight'] = param
else:
weights[f'{tllm_prex}.post_layernorm.bias'] = param
else:
if 'word_embeddings' in name:
if mapping.is_first_pp_rank():
weights['transformer.vocab_embedding.weight'] = param
if mapping.is_last_pp_rank():
weights['lm_head.weight'] = split_matrix(
param.clone(),
mapping.tp_size,
mapping.tp_rank,
dim=0)
elif 'ln_f' in name:
if mapping.is_last_pp_rank():
if name.endswith('weight'):
weights['transformer.ln_f.weight'] = param
else:
weights['transformer.ln_f.bias'] = param
del state_dict
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Weights loaded. Total time: {t}')
return weights


def load_from_awq_falcon(quant_ckpt_path: str,
hf_config: FalconConfig,
mapping: Mapping,
quantize_lm_head: bool = False,
dtype: str = "float16"):
weights = {}
tik = time.time()
num_hidden_layers = hf_config.num_hidden_layers
parallel_attention = hf_config.parallel_attn
new_decoder_architecture = hf_config.new_decoder_architecture
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)
if not quant_ckpt_path.endswith(".npz"):
raise ValueError("Unsupported AWQ quantized checkpoint format")
awq_falcon = np.load(quant_ckpt_path)
awq_prefix = "_np:"
awq_suffix_list = [
":weight",
":weights_scaling_factor",
":prequant_scaling_factor",
]
awq_key_list = [
"vocab_embedding:weight", # embedding
"lm_head", # lm_head
"final_layernorm", # ln_f
"attention:qkv:", # attention.qkv
"attention:dense", # attention.dense
"mlp:proj", # mlp.proj
"mlp:fc", # mlp.fc
"input_layernorm", # input_layernorm.weight
"mlp_layernorm", # mlp_layernorm.weight
]
split_sym = ":"
AMMO_WEIGHT_SCALING_FACTOR_COEFF = 7
def load(key):
if awq_prefix + key not in awq_falcon:
return None
v = torch.from_numpy(awq_falcon[awq_prefix + key]).to(torch_dtype)
if "weights_scaling_factor" in key:
v *= AMMO_WEIGHT_SCALING_FACTOR_COEFF # For AMMO *.npz checkpoints
return v
group_size = load("layers:0:attention:dense:weight").numel() // load(
"layers:0:attention:dense:weights_scaling_factor").numel()
def torch_split(v, dim):
if v.shape[dim] % mapping.tp_size != 0:
tensorrt_llm.logger.error(
"Current weight shape is invalid for mapping.tp_size=" +
str(mapping.tp_size))
raise ValueError("Invalid TP size")
return v.split(v.shape[dim] // mapping.tp_size,
dim=dim)[mapping.tp_rank]
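
    # The helper below rescales each group, rounds to the symmetric INT4 range
    # [-8, 7], packs two INT4 values per byte and preprocesses the layout for
    # the mixed-type GEMM kernel; the packed buffer is then viewed as float16
    # purely for storage.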
def AWQ_quantize_pack_preprocess(weight, scale):
weight /= scale.repeat_interleave(group_size, dim=0)
qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7)
int4_weight = preprocessor(packer(qweight_int8.cpu()), torch.quint4x2)
return int4_weight.view(torch.float16)
def get_tllm_weight_from_awq(v: List[torch.Tensor],
tllm_prex: str,
tp_dim: int = 0):
weight = v[0].T.contiguous()
[k, n] = weight.shape
weight = torch_split(weight, tp_dim)
amax = v[1].reshape((n, k // group_size)).T.contiguous()
amax = torch_split(amax, tp_dim)
pre_quant_scale = v[2].reshape((1, k))
if tp_dim == 0:
pre_quant_scale = torch_split(pre_quant_scale, 1)
scale = amax / 8.0
results = {
f'{tllm_prex}.weight': AWQ_quantize_pack_preprocess(weight, scale),
f'{tllm_prex}.weights_scaling_factor': scale.to(torch_dtype),
f'{tllm_prex}.prequant_scaling_factor':
pre_quant_scale.to(torch_dtype),
}
return results
def get_scale(weight):
[k, n] = weight.shape
weight_t = weight.T.contiguous()
weight_t = weight_t.reshape(n, k // group_size, group_size)
weight_t = torch.abs(weight_t.reshape(-1, group_size))
amax, idx = weight_t.max(1)
amax = amax.reshape(n, k // group_size).T.contiguous()
scale = amax / 8
return scale
def get_tllm_qkv_weight_from_awq(prefix, tllm_prex: str):
q_weight = load(prefix + "q" + awq_suffix_list[0]).T.contiguous()
k_weight = load(prefix + "k" + awq_suffix_list[0]).T.contiguous()
v_weight = load(prefix + "v" + awq_suffix_list[0]).T.contiguous()
dim_k = q_weight.shape[0]
q_weight = torch_split(q_weight, 1)
k_weight = torch_split(k_weight, 1)
v_weight = torch_split(v_weight, 1)
qkv_pre_quant_scale = load(prefix + "q" + awq_suffix_list[2]).reshape(
(1, dim_k))
qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
qkv_scale = get_scale(qkv_weights)
results = {
f'{tllm_prex}.prequant_scaling_factor':
qkv_pre_quant_scale.to(torch_dtype),
f'{tllm_prex}.weight':
AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale),
f'{tllm_prex}.weights_scaling_factor':
qkv_scale.to(torch_dtype),
}
return results
# Load weights from AWQ checkpoint into TRT-LLM module
# 1. embedding
v = load(awq_key_list[0])
# TRT-LLM requires vocab_size to be multiple of 64 for successful GEMM
if quantize_lm_head and v.shape[0] % 64 != 0:
v = torch.nn.functional.pad(v, [0, 0, 0, 64 - v.shape[0] % 64])
if mapping.is_first_pp_rank():
# tensorrt_llm_falcon.embedding.weight.value = v.to(torch_dtype)
weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype)
# 2. lm_head
if mapping.is_last_pp_rank():
if quantize_lm_head:
v = [load(awq_key_list[1] + suf) for suf in awq_suffix_list]
if v[0].shape[0] % 64 != 0:
v[0] = torch.nn.functional.pad(
v[0], [0, 0, 0, 64 - v[0].shape[0] % 64])
v[1] = torch.nn.functional.pad(
v[1], [0, 0, 0, 64 - v[1].shape[0] % 64], value=1)
# process_and_assign_weight(tensorrt_llm_falcon.lm_head, v, 1)
weights.update(get_tllm_weight_from_awq(v, 'lm_head', 1))
else:
v = load(awq_key_list[1] + awq_suffix_list[0])
weights['lm_head.weight'] = torch_split(v.to(torch_dtype), 0)
# 3. ln_f
v_weight = load(awq_key_list[2] + split_sym + "weight")
v_bias = load(awq_key_list[2] + split_sym + "bias")
if mapping.is_last_pp_rank():
# tensorrt_llm_falcon.ln_f.weight.value = v_weight.to(torch_dtype)
# tensorrt_llm_falcon.ln_f.bias.value = v_bias.to(torch_dtype)
weights['transformer.ln_f.weight'] = v_weight.to(torch_dtype)
weights['transformer.ln_f.bias'] = v_bias.to(torch_dtype)
# 4. Weights inside each layer
layers_range = mapping.pp_layers(num_hidden_layers)
for l in layers_range:
# layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage
# prefix = "layers" + split_sym + str(layer_idx) + split_sym
# tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
# layer = tensorrt_llm_falcon.layers[layer_idx]
prefix = f'layers{split_sym}{l}{split_sym}'
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
# 4.1 attention.qkv
# process_and_assign_qkv_weight(prefix + awq_key_list[3],
# layer.attention.qkv)
weights.update(
get_tllm_qkv_weight_from_awq(prefix + awq_key_list[3],
f'{tllm_prex}.attention.qkv'))
q_b = load(prefix + awq_key_list[3] + 'q:bias')
k_b = load(prefix + awq_key_list[3] + 'k:bias')
v_b = load(prefix + awq_key_list[3] + 'v:bias')
if q_b is not None:
q_b = torch_split(q_b, dim=0)
k_b = torch_split(k_b, dim=0)
v_b = torch_split(v_b, dim=0)
qkv_b = torch.cat((q_b, k_b, v_b), dim=0)
weights[f'{tllm_prex}.attention.qkv.bias'] = qkv_b
# 4.2 attention.dense
v = [load(prefix + awq_key_list[4] + suf) for suf in awq_suffix_list]
# process_and_assign_weight(layer.attention.dense, v, 0)
weights.update(
get_tllm_weight_from_awq(v,
f'{tllm_prex}.attention.dense',
tp_dim=0))
b = load(prefix + awq_key_list[4] + ':bias')
if b is not None:
if mapping.tp_rank > 0:
b = torch.zeros_like(b)
weights[f'{tllm_prex}.attention.dense.bias'] = b
        # 4.3 mlp.fc
v = [load(prefix + awq_key_list[6] + suf) for suf in awq_suffix_list]
# process_and_assign_weight(layer.mlp.fc, v, 1)
weights.update(
get_tllm_weight_from_awq(v, f'{tllm_prex}.mlp.fc', tp_dim=1))
b = load(prefix + awq_key_list[6] + ':bias')
if b is not None:
b = torch_split(b, dim=0)
weights[f'{tllm_prex}.mlp.fc.bias'] = b
        # 4.4 mlp.proj
v = [load(prefix + awq_key_list[5] + suf) for suf in awq_suffix_list]
# process_and_assign_weight(layer.mlp.proj, v, 0)
weights.update(
get_tllm_weight_from_awq(v, f'{tllm_prex}.mlp.proj', tp_dim=0))
b = load(prefix + awq_key_list[5] + ':bias')
if b is not None:
if mapping.tp_rank > 0:
b = torch.zeros_like(b)
weights[f'{tllm_prex}.mlp.proj.bias'] = b
if new_decoder_architecture:
# 4.5 input_layernorm
v = load(prefix + awq_key_list[7] + split_sym + "weight")
# layer.input_layernorm.weight.value = v.to(torch_dtype)
weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype)
v = load(prefix + awq_key_list[7] + split_sym + "bias")
# layer.input_layernorm.bias.value = v.to(torch_dtype)
weights[f'{tllm_prex}.input_layernorm.bias'] = v.to(torch_dtype)
# 4.6 mlp_layernorm
v = load(prefix + awq_key_list[8] + split_sym + "weight")
# layer.mlp_layernorm.weight.value = v.to(torch_dtype)
weights[f'{tllm_prex}.mlp_layernorm.weight'] = v.to(torch_dtype)
v = load(prefix + awq_key_list[8] + split_sym + "bias")
# layer.mlp_layernorm.bias.value = v.to(torch_dtype)
weights[f'{tllm_prex}.mlp_layernorm.bias'] = v.to(torch_dtype)
else:
# 4.5 input_layernorm
v = load(prefix + awq_key_list[7] + split_sym + "weight")
# layer.input_layernorm.weight.value = v.to(torch_dtype)
weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype)
v = load(prefix + awq_key_list[7] + split_sym + "bias")
# layer.input_layernorm.bias.value = v.to(torch_dtype)
weights[f'{tllm_prex}.input_layernorm.bias'] = v.to(torch_dtype)
if not parallel_attention:
# 4.6 post_layernorm
v = load(prefix + 'post_layernorm' + split_sym + "weight")
weights[f'{tllm_prex}.post_layernorm.weight'] = v.to(
torch_dtype)
v = load(prefix + 'post_layernorm' + split_sym + "bias")
weights[f'{tllm_prex}.post_layernorm.bias'] = v.to(torch_dtype)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Elapsed time: {t}')
return weights


def load_from_fp8_falcon(quant_ckpt_path: str, hf_config: FalconConfig,
mapping: Mapping, fp8_kv_cache: bool):
"""
Get the fp8 scaling factors for Falcon model.
"""
fake_fp8_sf_dt = torch.float32
fp8_falcon = np.load(quant_ckpt_path)
weights = {}
num_hidden_layers = hf_config.num_hidden_layers
layers_range = mapping.pp_layers(num_hidden_layers)
for l in layers_range:
prefix = f'_np:layers:{l}'
tllm_prex = f'transformer.layers.{l-layers_range[0]}'
weights[f'{tllm_prex}.attention.qkv.activation_scaling_factor'] = torch.tensor(
[
max(
fp8_falcon[
f'{prefix}:attention:qkv:q:activation_scaling_factor'].
item(), fp8_falcon[
f'{prefix}:attention:qkv:k:activation_scaling_factor'].
item(), fp8_falcon[
f'{prefix}:attention:qkv:v:activation_scaling_factor'].
item())
],
dtype=fake_fp8_sf_dt)
weights[
f'{tllm_prex}.attention.qkv.weights_scaling_factor'] = torch.tensor(
[
max(
fp8_falcon[
f'{prefix}:attention:qkv:q:weights_scaling_factor'].
item(), fp8_falcon[
f'{prefix}:attention:qkv:k:weights_scaling_factor'].
item(), fp8_falcon[
f'{prefix}:attention:qkv:v:weights_scaling_factor'].
item())
],
dtype=fake_fp8_sf_dt)
weights[
f'{tllm_prex}.attention.dense.activation_scaling_factor'] = torch.tensor(
[
fp8_falcon[
f'{prefix}:attention:dense:activation_scaling_factor'].
item()
],
dtype=fake_fp8_sf_dt)
weights[
f'{tllm_prex}.attention.dense.weights_scaling_factor'] = torch.tensor(
[
fp8_falcon[
f'{prefix}:attention:dense:weights_scaling_factor'].
item()
],
dtype=fake_fp8_sf_dt)
weights[f'{tllm_prex}.mlp.fc.activation_scaling_factor'] = torch.tensor(
[fp8_falcon[f'{prefix}:mlp:fc:activation_scaling_factor'].item()],
dtype=fake_fp8_sf_dt)
weights[f'{tllm_prex}.mlp.fc.weights_scaling_factor'] = torch.tensor(
[fp8_falcon[f'{prefix}:mlp:fc:weights_scaling_factor'].item()],
dtype=fake_fp8_sf_dt)
weights[
f'{tllm_prex}.mlp.proj.activation_scaling_factor'] = torch.tensor(
[
fp8_falcon[f'{prefix}:mlp:proj:activation_scaling_factor'].
item()
],
dtype=fake_fp8_sf_dt)
weights[f'{tllm_prex}.mlp.proj.weights_scaling_factor'] = torch.tensor(
[fp8_falcon[f'{prefix}:mlp:proj:weights_scaling_factor'].item()],
dtype=fake_fp8_sf_dt)
if fp8_kv_cache:
            # Not calibrating KV cache.
scaling_factor = 1.0
weights[
f'{tllm_prex}.attention.kv_orig_quant_scale'] = torch.tensor(
[scaling_factor], dtype=fake_fp8_sf_dt)
weights[
f'{tllm_prex}.attention.kv_quant_orig_scale'] = torch.tensor(
[1.0 / scaling_factor], dtype=fake_fp8_sf_dt)
return weights
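
# The FP8 path only produces per-tensor scaling factors (activation, weight
# and, optionally, KV-cache scales); the actual FP16/BF16 weights still come
# from convert_hf_falcon or load_from_hf_falcon_checkpoint and are merged with
# these scales in the main block below.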


if __name__ == '__main__':
# TODO(qijun): Currently, the convert script depends on a torch op:
# torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix,
# which is included in tensorrt_llm Python package. Otherwise, the convert
# script does not need to import tensorrt_llm. Will remove it after reimplementing
# the op with PyTorch.
print(tensorrt_llm.__version__)
args = parse_arguments()
tik = time.time()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
hf_config = load_falcon_config(args.model_dir)
config = {
'architecture': hf_config.architectures[0],
'dtype': args.dtype,
'logits_dtype': args.logits_dtype,
'num_hidden_layers': hf_config.num_hidden_layers,
'num_attention_heads': hf_config.num_attention_heads,
'num_key_value_heads': hf_config.num_kv_heads,
'hidden_size': hf_config.hidden_size,
'norm_epsilon': hf_config.layer_norm_epsilon,
'vocab_size': hf_config.vocab_size,
'position_embedding_type':
'alibi_with_scale' if hf_config.alibi else 'rope_gpt_neox',
'max_position_embeddings': hf_config.max_position_embeddings,
'hidden_act': 'gelu',
'quantization': {
'use_weight_only': args.use_weight_only,
'weight_only_precision': args.weight_only_precision,
'per_group': args.per_group,
'group_size': args.group_size,
'enable_fp8': args.enable_fp8,
'fp8_kv_cache': args.fp8_kv_cache,
},
'mapping': {
'world_size': args.world_size,
'tp_size': args.tp_size,
'pp_size': args.pp_size,
},
'bias': hf_config.bias,
'parallel_attention': hf_config.parallel_attn,
'new_decoder_architecture': hf_config.new_decoder_architecture,
}
if args.weight_only_precision == 'int4_awq':
exclude_modules = ['lm_head'] if not args.quantize_lm_head else []
config['quantization'].update({
'zero': False,
'pre_quant_scale': True,
'exclude_modules': exclude_modules,
})
if args.quantize_lm_head and config['vocab_size'] % 64 != 0:
config['vocab_size'] = int((config['vocab_size'] + 63) / 64) * 64
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
json.dump(config, f, indent=4)
    if args.weight_only_precision == 'int8':
        plugin_weight_only_quant_type = torch.int8
    elif args.weight_only_precision == 'int4':
        plugin_weight_only_quant_type = torch.quint4x2
    else:
        # int4_awq checkpoints are handled by load_from_awq_falcon; the plugin
        # quant type is unused in that path.
        plugin_weight_only_quant_type = None

    def convert_and_save(rank):
mapping = Mapping(world_size=args.tp_size * args.pp_size,
rank=rank,
tp_size=args.tp_size,
pp_size=args.pp_size)
if args.use_weight_only and args.weight_only_precision == 'int4_awq':
weights = load_from_awq_falcon(
args.ammo_quant_ckpt_path,
hf_config,
mapping,
quantize_lm_head=args.quantize_lm_head,
dtype=args.dtype)
else:
if args.load_by_shard:
weights = load_from_hf_falcon_checkpoint(
args.model_dir,
hf_config,
mapping,
dtype=args.dtype,
use_weight_only=args.use_weight_only,
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
else:
hf_model = AutoModelForCausalLM.from_pretrained(
args.model_dir, trust_remote_code=True, torch_dtype="auto")
weights = convert_hf_falcon(
hf_model,
hf_config,
mapping,
dtype=args.dtype,
use_weight_only=args.use_weight_only,
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
del hf_model
if args.enable_fp8 or args.fp8_kv_cache:
scales = load_from_fp8_falcon(args.ammo_quant_ckpt_path,
hf_config, mapping,
args.fp8_kv_cache)
weights.update(scales)
safetensors.torch.save_file(
weights, os.path.join(args.output_dir, f'rank{rank}.safetensors'))
if args.workers == 1:
for rank in range(args.world_size):
            convert_and_save(rank)
else:
with ThreadPoolExecutor(max_workers=args.workers) as p:
futures = [
                p.submit(convert_and_save, rank)
for rank in range(args.world_size)
]
wait(futures)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Total time of converting checkpoints: {t}')